mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
jagrit's commit files
This commit is contained in:
parent
d1f86272a2
commit
e6306cfee9
63
README.md
63
README.md
@ -1,2 +1,61 @@
|
|||||||
# mlx
|
# MLX
|
||||||
MLX: An array framework for Apple silicon
|
|
||||||
|
MLX is an array framework for machine learning specifically targeting Apple
|
||||||
|
Silicon. MLX is designed with inspiration from Jax, PyTorch, ArrayFire.
|
||||||
|
|
||||||
|
[Documentation](https://at.apple.com/mlx)
|
||||||
|
|
||||||
|
## Build
|
||||||
|
|
||||||
|
```
|
||||||
|
mkdir -p build && cd build
|
||||||
|
cmake .. && make -j
|
||||||
|
```
|
||||||
|
|
||||||
|
Run the C++ tests with `make test` (or `./tests/tests` for more detailed output).
|
||||||
|
|
||||||
|
### Python bidings
|
||||||
|
|
||||||
|
To install run:
|
||||||
|
|
||||||
|
`
|
||||||
|
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
|
||||||
|
`
|
||||||
|
|
||||||
|
For developing use an editable install:
|
||||||
|
|
||||||
|
```
|
||||||
|
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
To make sure the install is working run the tests with:
|
||||||
|
|
||||||
|
```
|
||||||
|
python -m unittest discover python/tests
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Develop
|
||||||
|
|
||||||
|
- Fork and submit pull requests to the repo.
|
||||||
|
|
||||||
|
- Every PR should have passing tests and at least one review.
|
||||||
|
|
||||||
|
- If a change is likely to impact efficiency, run some of the benchmarks before
|
||||||
|
and after the change. Examples of benchmarks can be found in `benchmarks/cpp/`.
|
||||||
|
|
||||||
|
- Install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
|
||||||
|
This should install hooks for running `black` and `clang-format` to ensure
|
||||||
|
consistent style for C++ and python code.
|
||||||
|
|
||||||
|
You can also run the formatters manually as follows:
|
||||||
|
|
||||||
|
```
|
||||||
|
clang-format -i file.cpp
|
||||||
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
black file.py
|
||||||
|
```
|
||||||
|
|
||||||
|
or run `pre-commit run --all-files` to check all files in the repo.
|
||||||
|
11
benchmarks/cpp/CMakeLists.txt
Normal file
11
benchmarks/cpp/CMakeLists.txt
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
function(build_benchmark SRCFILE)
|
||||||
|
get_filename_component(src_name ${SRCFILE} NAME_WE)
|
||||||
|
set(target "${src_name}")
|
||||||
|
add_executable(${target} ${SRCFILE})
|
||||||
|
target_link_libraries(${target} PRIVATE mlx)
|
||||||
|
endfunction(build_benchmark)
|
||||||
|
|
||||||
|
build_benchmark(single_ops.cpp)
|
||||||
|
build_benchmark(irregular_strides.cpp)
|
||||||
|
build_benchmark(compare_devices.cpp)
|
||||||
|
build_benchmark(autograd.cpp)
|
37
benchmarks/cpp/autograd.cpp
Normal file
37
benchmarks/cpp/autograd.cpp
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "mlx/mlx.h"
|
||||||
|
#include "time_utils.h"
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
void time_value_and_grad() {
|
||||||
|
auto x = ones({200, 1000});
|
||||||
|
eval(x);
|
||||||
|
auto fn = [](array x) {
|
||||||
|
for (int i = 0; i < 20; ++i) {
|
||||||
|
x = log(exp(x));
|
||||||
|
}
|
||||||
|
return sum(x);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto grad_fn = grad(fn);
|
||||||
|
auto independent_value_and_grad = [&]() {
|
||||||
|
auto value = fn(x);
|
||||||
|
auto dfdx = grad_fn(x);
|
||||||
|
return std::vector<array>{value, dfdx};
|
||||||
|
};
|
||||||
|
TIME(independent_value_and_grad);
|
||||||
|
|
||||||
|
auto value_and_grad_fn = value_and_grad(fn);
|
||||||
|
auto combined_value_and_grad = [&]() {
|
||||||
|
auto [value, dfdx] = value_and_grad_fn(x);
|
||||||
|
return std::vector<array>{value, dfdx};
|
||||||
|
};
|
||||||
|
TIME(combined_value_and_grad);
|
||||||
|
}
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
std::cout << "Benchmarks for " << default_device() << std::endl;
|
||||||
|
time_value_and_grad();
|
||||||
|
}
|
25
benchmarks/cpp/compare_devices.cpp
Normal file
25
benchmarks/cpp/compare_devices.cpp
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
#include <iostream>
|
||||||
|
#include "mlx/mlx.h"
|
||||||
|
#include "time_utils.h"
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
void time_add_op() {
|
||||||
|
std::vector<int> sizes(1, 1);
|
||||||
|
for (int i = 0; i < 9; ++i) {
|
||||||
|
sizes.push_back(10 * sizes.back());
|
||||||
|
}
|
||||||
|
set_default_device(Device::cpu);
|
||||||
|
for (auto size : sizes) {
|
||||||
|
auto a = random::uniform({size});
|
||||||
|
auto b = random::uniform({size});
|
||||||
|
eval(a, b);
|
||||||
|
std::cout << "Size " << size << std::endl;
|
||||||
|
TIMEM("cpu", add, a, b, Device::cpu);
|
||||||
|
TIMEM("gpu", add, a, b, Device::gpu);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
time_add_op();
|
||||||
|
}
|
38
benchmarks/numpy/single_ops.py
Normal file
38
benchmarks/numpy/single_ops.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from time_utils import time_fn
|
||||||
|
|
||||||
|
|
||||||
|
def time_add():
|
||||||
|
a = np.ones((100, 100, 10), dtype=np.float32)
|
||||||
|
b = np.ones((100, 100, 10), dtype=np.float32)
|
||||||
|
time_fn(np.add, a, b)
|
||||||
|
|
||||||
|
|
||||||
|
def time_matmul():
|
||||||
|
a = np.random.rand(1000, 500).astype(np.float32)
|
||||||
|
b = np.random.rand(500, 1000).astype(np.float32)
|
||||||
|
time_fn(np.matmul, a, b)
|
||||||
|
|
||||||
|
|
||||||
|
def time_exp():
|
||||||
|
a = np.random.randn(1000, 100).astype(np.float32)
|
||||||
|
time_fn(np.exp, a)
|
||||||
|
|
||||||
|
|
||||||
|
def time_take():
|
||||||
|
a = np.random.rand(10000, 500)
|
||||||
|
ids = np.random.randint(0, 10000, (20, 10))
|
||||||
|
ids = [idx.reshape(-1) for idx in np.split(ids, 20)]
|
||||||
|
|
||||||
|
def random_take():
|
||||||
|
return [np.take(a, idx, 0) for idx in ids]
|
||||||
|
|
||||||
|
time_fn(random_take)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
time_add()
|
||||||
|
time_matmul()
|
||||||
|
time_exp()
|
||||||
|
time_take()
|
18
benchmarks/numpy/time_utils.py
Normal file
18
benchmarks/numpy/time_utils.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
def time_fn(fn, *args):
|
||||||
|
print(f"Timing {fn.__name__} ...", end=" ")
|
||||||
|
|
||||||
|
# warmup
|
||||||
|
for _ in range(5):
|
||||||
|
fn(*args)
|
||||||
|
|
||||||
|
num_iters = 100
|
||||||
|
tic = time.perf_counter()
|
||||||
|
for _ in range(num_iters):
|
||||||
|
x = fn(*args)
|
||||||
|
toc = time.perf_counter()
|
||||||
|
|
||||||
|
msec = 1e3 * (toc - tic) / num_iters
|
||||||
|
print(f"{msec:.5f} msec")
|
190
benchmarks/python/blas/bench_gemm.py
Normal file
190
benchmarks/python/blas/bench_gemm.py
Normal file
@ -0,0 +1,190 @@
|
|||||||
|
import numpy as np
|
||||||
|
import argparse
|
||||||
|
import mlx.core as mx
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
import math
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
|
||||||
|
device_name = device_name.decode("utf-8").strip("\n")
|
||||||
|
|
||||||
|
N_warmup = 8
|
||||||
|
N_iter_bench = 80
|
||||||
|
N_iter_func = 5
|
||||||
|
|
||||||
|
|
||||||
|
def bench(f, a, b):
|
||||||
|
for i in range(N_warmup):
|
||||||
|
f(a, b)
|
||||||
|
torch.mps.synchronize()
|
||||||
|
|
||||||
|
s = time.perf_counter_ns()
|
||||||
|
for i in range(N_iter_bench):
|
||||||
|
f(a, b)
|
||||||
|
e = time.perf_counter_ns()
|
||||||
|
return (e - s) * 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
def gemm_nn_mlx(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = a @ b
|
||||||
|
ys.append(y)
|
||||||
|
mx.eval(ys)
|
||||||
|
return ys
|
||||||
|
|
||||||
|
|
||||||
|
def gemm_nt_mlx(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = a @ b.transpose((0, 2, 1))
|
||||||
|
ys.append(y)
|
||||||
|
mx.eval(ys)
|
||||||
|
return ys
|
||||||
|
|
||||||
|
|
||||||
|
def gemm_tn_mlx(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = a.transpose((0, 2, 1)) @ b
|
||||||
|
ys.append(y)
|
||||||
|
mx.eval(ys)
|
||||||
|
return ys
|
||||||
|
|
||||||
|
|
||||||
|
def gemm_tt_mlx(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = a.transpose((0, 2, 1)) @ b.transpose((0, 2, 1))
|
||||||
|
ys.append(y)
|
||||||
|
mx.eval(ys)
|
||||||
|
return ys
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def gemm_nn_torch(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = a @ b
|
||||||
|
ys.append(y)
|
||||||
|
torch.mps.synchronize()
|
||||||
|
return ys
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def gemm_nt_torch(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = a @ b.transpose(-1, -2)
|
||||||
|
ys.append(y)
|
||||||
|
torch.mps.synchronize()
|
||||||
|
return ys
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def gemm_tn_torch(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = a.transpose(-1, -2) @ b
|
||||||
|
ys.append(y)
|
||||||
|
torch.mps.synchronize()
|
||||||
|
return ys
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def gemm_tt_torch(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = a.transpose(-1, -2) @ b.transpose(-1, -2)
|
||||||
|
ys.append(y)
|
||||||
|
torch.mps.synchronize()
|
||||||
|
return ys
|
||||||
|
|
||||||
|
|
||||||
|
def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
|
||||||
|
shape_a = (B, M, K) if transpose[0] == "n" else (B, K, M)
|
||||||
|
shape_b = (B, K, N) if transpose[1] == "n" else (B, N, K)
|
||||||
|
|
||||||
|
a_np = np.random.normal(0.0, 1.0 / math.sqrt(M + K), shape_a).astype(np_dtype)
|
||||||
|
b_np = np.random.normal(0.0, 1.0 / math.sqrt(N + K), shape_b).astype(np_dtype)
|
||||||
|
|
||||||
|
a_mx = mx.array(a_np)
|
||||||
|
b_mx = mx.array(b_np)
|
||||||
|
|
||||||
|
a_pt = torch.from_numpy(a_np).to("mps")
|
||||||
|
b_pt = torch.from_numpy(b_np).to("mps")
|
||||||
|
|
||||||
|
torch.mps.synchronize()
|
||||||
|
|
||||||
|
f_mx = {
|
||||||
|
"nn": gemm_nn_mlx,
|
||||||
|
"nt": gemm_nt_mlx,
|
||||||
|
"tn": gemm_tn_mlx,
|
||||||
|
"tt": gemm_tt_mlx,
|
||||||
|
}[transpose]
|
||||||
|
|
||||||
|
f_pt = {
|
||||||
|
"nn": gemm_nn_torch,
|
||||||
|
"nt": gemm_nt_torch,
|
||||||
|
"tn": gemm_tn_torch,
|
||||||
|
"tt": gemm_tt_torch,
|
||||||
|
}[transpose]
|
||||||
|
|
||||||
|
time_torch = bench(f_pt, a_pt, b_pt)
|
||||||
|
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||||
|
|
||||||
|
t_a = (0, 1, 2) if transpose[0] == "n" else (0, 2, 1)
|
||||||
|
t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)
|
||||||
|
|
||||||
|
c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
|
||||||
|
c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
|
||||||
|
np.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
atol = 1e-5 if np_dtype == np.float32 else 1e-4
|
||||||
|
|
||||||
|
if not np.allclose(c_mlx, c_npy.astype(np_dtype), atol=atol):
|
||||||
|
print(
|
||||||
|
f"Failed at {(B, M, N, K)} [transpose = {transpose}] with max(|a - b|) = {np.max(np.abs(c_npy - c_mlx))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return time_mlx, time_torch
|
||||||
|
|
||||||
|
|
||||||
|
def get_gflop_count(B, M, N, K):
|
||||||
|
return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
|
||||||
|
|
||||||
|
dtypes = ("float32", "float16")
|
||||||
|
transposes = ("nn", "nt", "tn")
|
||||||
|
shapes = (
|
||||||
|
(16, 1024, 1024, 1024),
|
||||||
|
(1, 1024, 1024, 2048),
|
||||||
|
(4, 1024, 1024, 4096),
|
||||||
|
(4, 1024, 4096, 1024),
|
||||||
|
(1, 4096, 4096, 4096),
|
||||||
|
(15, 1023, 1023, 1023),
|
||||||
|
(17, 1025, 1025, 1025),
|
||||||
|
)
|
||||||
|
|
||||||
|
for dtype in dtypes:
|
||||||
|
for transpose in transposes:
|
||||||
|
for B, M, N, K in shapes:
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
time_mlx, time_torch = bench_shape(B, M, N, K, np_dtype, transpose)
|
||||||
|
|
||||||
|
gflop_count = get_gflop_count(B, M, N, K)
|
||||||
|
gflops_mx = gflop_count / (time_mlx)
|
||||||
|
gflops_pt = gflop_count / (time_torch)
|
||||||
|
diff = gflops_mx / gflops_pt - 1.0
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
|
||||||
|
)
|
||||||
|
if gflops_pt >= 2.0 * gflops_mx:
|
||||||
|
print("ATTENTION ^^^^^^^")
|
219
benchmarks/python/blas/bench_gemv.py
Normal file
219
benchmarks/python/blas/bench_gemv.py
Normal file
@ -0,0 +1,219 @@
|
|||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import argparse
|
||||||
|
import mlx.core as mx
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
|
||||||
|
results_dir = "./results"
|
||||||
|
|
||||||
|
if not os.path.isdir(results_dir):
|
||||||
|
os.mkdir(results_dir)
|
||||||
|
|
||||||
|
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
|
||||||
|
device_name = device_name.decode("utf-8").strip("\n")
|
||||||
|
|
||||||
|
N_warmup = 5
|
||||||
|
N_iter_bench = 50
|
||||||
|
N_iter_func = 20
|
||||||
|
|
||||||
|
out_vec_sizes = [128, 512, 2048, 4096]
|
||||||
|
in_vec_sizes = [128, 512, 2048, 4096]
|
||||||
|
|
||||||
|
benchmark_vector_lens = []
|
||||||
|
benchmark_vector_lens += [(i + 1) * 4096 for i in range(8)][::2]
|
||||||
|
benchmark_vector_lens += [(i + 1) * 4095 for i in range(8)][::2]
|
||||||
|
benchmark_vector_lens += [(i + 1) * 4097 for i in range(8)][::2]
|
||||||
|
benchmark_vector_lens += [64, 128, 512, 1024, 2048, 11008, 32000]
|
||||||
|
|
||||||
|
benchmark_vector_lens.sort()
|
||||||
|
|
||||||
|
|
||||||
|
def bench(f, m, v):
|
||||||
|
for i in range(N_warmup):
|
||||||
|
f(m, v)
|
||||||
|
torch.mps.synchronize()
|
||||||
|
|
||||||
|
s = time.perf_counter_ns()
|
||||||
|
for i in range(N_iter_bench):
|
||||||
|
f(m, v)
|
||||||
|
e = time.perf_counter_ns()
|
||||||
|
return (e - s) * 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
def gemv_mlx(m, v):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = m @ v
|
||||||
|
ys.append(y)
|
||||||
|
mx.eval(ys)
|
||||||
|
return ys
|
||||||
|
|
||||||
|
|
||||||
|
def gemv_t_mlx(m, v):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = v @ m
|
||||||
|
ys.append(y)
|
||||||
|
mx.eval(ys)
|
||||||
|
return ys
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def gemv_torch(m, v):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = m @ v
|
||||||
|
ys.append(y)
|
||||||
|
torch.mps.synchronize()
|
||||||
|
return ys
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def gemv_t_torch(m, v):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = v @ m
|
||||||
|
ys.append(y)
|
||||||
|
torch.mps.synchronize()
|
||||||
|
return ys
|
||||||
|
|
||||||
|
|
||||||
|
def bench_lens(in_vec_len, out_vec_len, np_dtype, transpose=False):
|
||||||
|
shape_mat = (in_vec_len, out_vec_len) if transpose else (out_vec_len, in_vec_len)
|
||||||
|
shape_vec = (1, in_vec_len) if transpose else (in_vec_len, 1)
|
||||||
|
|
||||||
|
mat_npy = np.random.normal(0.0, 2.0 / in_vec_len, shape_mat).astype(np_dtype)
|
||||||
|
vec_npy = np.random.normal(0.0, 2.0 / in_vec_len, shape_vec).astype(np_dtype)
|
||||||
|
mat_mlx = mx.array(mat_npy)
|
||||||
|
vec_mlx = mx.array(vec_npy)
|
||||||
|
mat_trc = torch.from_numpy(mat_npy).to("mps")
|
||||||
|
vec_trc = torch.from_numpy(vec_npy).to("mps")
|
||||||
|
|
||||||
|
torch.mps.synchronize()
|
||||||
|
|
||||||
|
time_torch = (
|
||||||
|
bench(gemv_t_torch, mat_trc, vec_trc)
|
||||||
|
if transpose
|
||||||
|
else bench(gemv_torch, mat_trc, vec_trc)
|
||||||
|
)
|
||||||
|
time_mlx = (
|
||||||
|
bench(gemv_t_mlx, mat_mlx, vec_mlx)
|
||||||
|
if transpose
|
||||||
|
else bench(gemv_mlx, mat_mlx, vec_mlx)
|
||||||
|
)
|
||||||
|
|
||||||
|
c_mlx = (
|
||||||
|
np.asarray(vec_mlx @ mat_mlx) if transpose else np.asarray(mat_mlx @ vec_mlx)
|
||||||
|
)
|
||||||
|
c_npy = (vec_npy @ mat_npy) if transpose else (mat_npy @ vec_npy)
|
||||||
|
|
||||||
|
if not np.allclose(c_mlx, c_npy, atol=2e-5):
|
||||||
|
print(
|
||||||
|
f"Failed at {shape_mat} [transpose = {transpose}] with max(|a - b|) = {np.max(np.abs(c_npy - c_mlx))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return time_mlx, time_torch
|
||||||
|
|
||||||
|
|
||||||
|
def get_gflop_count(in_vec_len, out_vec_len):
|
||||||
|
return float(2.0 * N_iter_bench * N_iter_func * in_vec_len * out_vec_len) / float(
|
||||||
|
1024**3
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_gbyte_size(in_vec_len, out_vec_len, np_dtype):
|
||||||
|
n_elem = in_vec_len * out_vec_len + in_vec_len + out_vec_len
|
||||||
|
item_size = 4 if np_dtype == np.float32 else 2
|
||||||
|
return float(N_iter_bench * N_iter_func * n_elem * item_size) / float(1024**3)
|
||||||
|
|
||||||
|
|
||||||
|
def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, tranpose):
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
mlx_gb_s = []
|
||||||
|
mlx_gflops = []
|
||||||
|
pyt_gb_s = []
|
||||||
|
pyt_gflops = []
|
||||||
|
|
||||||
|
for out_vec_len in out_vector_lens:
|
||||||
|
gflop_count = get_gflop_count(in_vec_len, out_vec_len)
|
||||||
|
gbyte_size = get_gbyte_size(in_vec_len, out_vec_len, np_dtype)
|
||||||
|
|
||||||
|
time_mlx, time_torch = bench_lens(in_vec_len, out_vec_len, np_dtype, transpose)
|
||||||
|
|
||||||
|
mlx_gb_s.append(gbyte_size / time_mlx)
|
||||||
|
pyt_gb_s.append(gbyte_size / time_torch)
|
||||||
|
|
||||||
|
mlx_gflops.append(gflop_count / time_mlx)
|
||||||
|
pyt_gflops.append(gflop_count / time_torch)
|
||||||
|
|
||||||
|
if transpose:
|
||||||
|
title = f"gemv_t ([1, {in_vec_len}] [{in_vec_len}, out_vec_len]) | {dtype}"
|
||||||
|
else:
|
||||||
|
title = f"gemv ([out_vec_len, {in_vec_len}] X [{in_vec_len}, 1] ) | {dtype}"
|
||||||
|
|
||||||
|
ax.plot(out_vector_lens, mlx_gb_s, "tab:blue", label="MLX")
|
||||||
|
ax.plot(out_vector_lens, pyt_gb_s, "tab:red", label="Torch")
|
||||||
|
ax.set_title(title)
|
||||||
|
ax.set(xlabel="out_vector_len", ylabel="Performance (GB/s)")
|
||||||
|
ax.legend()
|
||||||
|
|
||||||
|
|
||||||
|
def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, tranpose):
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
mlx_gb_s = []
|
||||||
|
mlx_gflops = []
|
||||||
|
pyt_gb_s = []
|
||||||
|
pyt_gflops = []
|
||||||
|
|
||||||
|
for in_vec_len in in_vector_lens:
|
||||||
|
gflop_count = get_gflop_count(in_vec_len, out_vec_len)
|
||||||
|
gbyte_size = get_gbyte_size(in_vec_len, out_vec_len, np_dtype)
|
||||||
|
|
||||||
|
time_mlx, time_torch = bench_lens(in_vec_len, out_vec_len, np_dtype, transpose)
|
||||||
|
|
||||||
|
mlx_gb_s.append(gbyte_size / time_mlx)
|
||||||
|
pyt_gb_s.append(gbyte_size / time_torch)
|
||||||
|
|
||||||
|
mlx_gflops.append(gflop_count / time_mlx)
|
||||||
|
pyt_gflops.append(gflop_count / time_torch)
|
||||||
|
|
||||||
|
if transpose:
|
||||||
|
title = f"([1, in_vec_len] [in_vec_len, {out_vec_len}])"
|
||||||
|
else:
|
||||||
|
title = f"([{out_vec_len}, in_vec_len] X [in_vec_len, 1] )"
|
||||||
|
|
||||||
|
ax.plot(in_vector_lens, mlx_gb_s, "tab:blue", label="MLX")
|
||||||
|
ax.plot(in_vector_lens, pyt_gb_s, "tab:red", label="Torch")
|
||||||
|
ax.set_title(title)
|
||||||
|
ax.set(xlabel="in_vector_len", ylabel="Performance (GB/s)")
|
||||||
|
ax.legend()
|
||||||
|
|
||||||
|
|
||||||
|
for transpose in (False, True):
|
||||||
|
for dtype in ("float32", "float16"):
|
||||||
|
fig, axs = plt.subplots(
|
||||||
|
len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, in_vec_len in enumerate(in_vec_sizes):
|
||||||
|
bench_with_in_len(
|
||||||
|
axs[i][0], in_vec_len, benchmark_vector_lens, dtype, transpose
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, out_vec_len in enumerate(out_vec_sizes):
|
||||||
|
bench_with_out_len(
|
||||||
|
axs[i][1], out_vec_len, benchmark_vector_lens, dtype, transpose
|
||||||
|
)
|
||||||
|
|
||||||
|
op_name = "gemv_t" if transpose else "gemv"
|
||||||
|
fig.suptitle(f"{device_name}: {dtype} {op_name}")
|
||||||
|
fig.savefig(
|
||||||
|
os.path.join(
|
||||||
|
results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
plt.close(fig)
|
116
benchmarks/python/llama_mlx_bench.py
Normal file
116
benchmarks/python/llama_mlx_bench.py
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
import math
|
||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
import mlx.utils
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaAttention(nn.Module):
|
||||||
|
def __init__(self, dims: int, num_heads: int):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.rope = nn.RoPE(dims // num_heads, True)
|
||||||
|
self.query_proj = nn.Linear(dims, dims, False)
|
||||||
|
self.key_proj = nn.Linear(dims, dims, False)
|
||||||
|
self.value_proj = nn.Linear(dims, dims, False)
|
||||||
|
self.out_proj = nn.Linear(dims, dims, False)
|
||||||
|
|
||||||
|
def __call__(self, queries, keys, values, mask=None, cache=None):
|
||||||
|
queries = self.query_proj(queries)
|
||||||
|
keys = self.key_proj(keys)
|
||||||
|
values = self.value_proj(values)
|
||||||
|
|
||||||
|
num_heads = self.num_heads
|
||||||
|
B, L, D = queries.shape
|
||||||
|
queries = mx.transpose(mx.reshape(queries, (B, L, num_heads, -1)), (0, 2, 1, 3))
|
||||||
|
keys = mx.transpose(mx.reshape(keys, (B, L, num_heads, -1)), (0, 2, 1, 3))
|
||||||
|
values = mx.transpose(mx.reshape(values, (B, L, num_heads, -1)), (0, 2, 1, 3))
|
||||||
|
|
||||||
|
if cache is not None:
|
||||||
|
key_cache, value_cache = cache
|
||||||
|
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||||
|
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||||
|
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||||
|
values = mx.concatenate([value_cache, values], axis=2)
|
||||||
|
else:
|
||||||
|
queries = self.rope(queries)
|
||||||
|
keys = self.rope(keys)
|
||||||
|
|
||||||
|
# Dimensions are [batch x num heads x sequence x hidden dim]
|
||||||
|
scale = mx.array(math.sqrt(1 / queries.shape[-1]), dtype=queries.dtype)
|
||||||
|
scores = (queries * scale) @ mx.transpose(keys, (0, 1, 3, 2))
|
||||||
|
if mask is not None:
|
||||||
|
scores = scores + mask
|
||||||
|
scores = mx.softmax(scores, axis=-1)
|
||||||
|
values_hat = mx.reshape(mx.transpose(scores @ values, (0, 2, 1, 3)), (B, L, -1))
|
||||||
|
|
||||||
|
return self.out_proj(values_hat), (keys, values)
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaEncoderLayer(nn.Module):
|
||||||
|
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.attention = LlamaAttention(dims, num_heads)
|
||||||
|
|
||||||
|
self.norm1 = nn.RMSNorm(dims)
|
||||||
|
self.norm2 = nn.RMSNorm(dims)
|
||||||
|
|
||||||
|
self.linear1 = nn.Linear(dims, mlp_dims, False)
|
||||||
|
self.linear2 = nn.Linear(dims, mlp_dims, False)
|
||||||
|
self.linear3 = nn.Linear(mlp_dims, dims, False)
|
||||||
|
|
||||||
|
def __call__(self, x, mask=None, cache=None):
|
||||||
|
y = self.norm1(x)
|
||||||
|
y, cache = self.attention(y, y, y, mask, cache)
|
||||||
|
x = x + y
|
||||||
|
|
||||||
|
y = self.norm2(x)
|
||||||
|
a = self.linear1(y)
|
||||||
|
b = self.linear2(y)
|
||||||
|
y = a * mx.sigmoid(a) * b
|
||||||
|
y = self.linear3(y)
|
||||||
|
x = x + y
|
||||||
|
|
||||||
|
return x, cache
|
||||||
|
|
||||||
|
|
||||||
|
def measure(model, x, cache):
|
||||||
|
for i in range(5):
|
||||||
|
y, c = model(x, mask=None, cache=cache)
|
||||||
|
mx.eval(y, c)
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
rs = []
|
||||||
|
for i in range(5):
|
||||||
|
y, c = model(x, mask=None, cache=cache)
|
||||||
|
rs.append((y, c))
|
||||||
|
mx.eval(rs)
|
||||||
|
end = time.time()
|
||||||
|
|
||||||
|
return (end - start) * 1000 / 5
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
H = 32
|
||||||
|
D = 4096
|
||||||
|
F = 43 * 256
|
||||||
|
C = 1000
|
||||||
|
mx.set_default_device(mx.gpu)
|
||||||
|
dtype = mx.float16
|
||||||
|
|
||||||
|
layer = LlamaEncoderLayer(D, F, H)
|
||||||
|
layer.update(mlx.utils.tree_map(lambda x: x.astype(dtype), layer.parameters()))
|
||||||
|
k1, k2, k3 = mx.random.split(mx.random.key(0), 3)
|
||||||
|
x = mx.random.normal([1, 1, D], dtype=dtype)
|
||||||
|
cache = [
|
||||||
|
mx.random.normal([1, H, C, D // H], dtype=dtype),
|
||||||
|
mx.random.normal([1, H, C, D // H], dtype=dtype),
|
||||||
|
]
|
||||||
|
mx.eval(x, cache)
|
||||||
|
|
||||||
|
T = measure(layer, x, cache)
|
||||||
|
|
||||||
|
print("Time per layer per token:", T, "ms")
|
||||||
|
print("Lower bound total time per token:", T * 32, "ms")
|
56
cmake/extension.cmake
Normal file
56
cmake/extension.cmake
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
include(CMakeParseArguments)
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# Build metal library
|
||||||
|
#
|
||||||
|
# Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib
|
||||||
|
# from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}
|
||||||
|
#
|
||||||
|
# Args:
|
||||||
|
# TARGET: Custom target to be added for the metal library
|
||||||
|
# TITLE: Name of the .metallib
|
||||||
|
# OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
|
||||||
|
# SOURCES: List of source files
|
||||||
|
# INCLUDE_DIRS: List of include dirs
|
||||||
|
# DEPS: List of depedency files (like headers)
|
||||||
|
#
|
||||||
|
macro(mlx_build_metallib)
|
||||||
|
# Parse args
|
||||||
|
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
|
||||||
|
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
|
||||||
|
cmake_parse_arguments(
|
||||||
|
MTLLIB
|
||||||
|
""
|
||||||
|
"${oneValueArgs}"
|
||||||
|
"${multiValueArgs}"
|
||||||
|
${ARGN}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set output
|
||||||
|
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
|
||||||
|
|
||||||
|
# Collect compile options
|
||||||
|
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
|
||||||
|
|
||||||
|
# Prepare metllib build command
|
||||||
|
add_custom_command(
|
||||||
|
OUTPUT ${MTLLIB_BUILD_TARGET}
|
||||||
|
COMMAND xcrun -sdk macosx metal
|
||||||
|
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
|
||||||
|
${MTLLIB_COMPILE_OPTIONS}
|
||||||
|
${MTLLIB_SOURCES}
|
||||||
|
-o ${MTLLIB_BUILD_TARGET}
|
||||||
|
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
|
||||||
|
COMMAND_EXPAND_LISTS
|
||||||
|
COMMENT "Building ${MTLLIB_TITLE}.metallib"
|
||||||
|
VERBATIM
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add metallib custom target
|
||||||
|
add_custom_target(
|
||||||
|
${MTLLIB_TARGET}
|
||||||
|
DEPENDS
|
||||||
|
${MTLLIB_BUILD_TARGET}
|
||||||
|
)
|
||||||
|
|
||||||
|
endmacro(mlx_build_metallib)
|
0
docs/.nojekyll
Normal file
0
docs/.nojekyll
Normal file
1
docs/index.html
Normal file
1
docs/index.html
Normal file
@ -0,0 +1 @@
|
|||||||
|
<meta http-equiv="refresh" content="0; url=./build/html/index.html" />
|
19
docs/src/_templates/nn-module-template.rst
Normal file
19
docs/src/_templates/nn-module-template.rst
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
{{ fullname | escape | underline}}
|
||||||
|
|
||||||
|
.. currentmodule:: {{ module }}
|
||||||
|
|
||||||
|
.. autoclass:: {{ objname }}
|
||||||
|
|
||||||
|
{#{% block methods %}
|
||||||
|
|
||||||
|
{% if methods %}
|
||||||
|
.. rubric:: {{ _('Methods') }}
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
{% for item in methods %}
|
||||||
|
{%- if item not in inherited_members and item != '__init__' %}
|
||||||
|
~{{ name }}.{{ item }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{% endif %}
|
||||||
|
{% endblock %}#}
|
6
docs/src/cpp/ops.rst
Normal file
6
docs/src/cpp/ops.rst
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
.. _cpp_ops:
|
||||||
|
|
||||||
|
Operations
|
||||||
|
==========
|
||||||
|
|
||||||
|
|
948
docs/src/dev/extensions.rst
Normal file
948
docs/src/dev/extensions.rst
Normal file
@ -0,0 +1,948 @@
|
|||||||
|
Developer Documentation
|
||||||
|
=======================
|
||||||
|
|
||||||
|
MLX provides a open and flexible backend to which users may add operations
|
||||||
|
and specialized implementations without much hassle. While the library supplies
|
||||||
|
efficient operations that can be used and composed for any number of
|
||||||
|
applications, there may arise cases where new functionalities or highly
|
||||||
|
optimized implementations are needed. For such cases, you may design and
|
||||||
|
implement your own operations that link to and build on top of :mod:`mlx.core`.
|
||||||
|
We will introduce the inner-workings of MLX and go over a simple example to
|
||||||
|
learn the steps involved in adding new operations to MLX with your own CPU
|
||||||
|
and GPU implementations.
|
||||||
|
|
||||||
|
Introducing the Example
|
||||||
|
-----------------------
|
||||||
|
|
||||||
|
Let's say that you would like an operation that takes in two arrays,
|
||||||
|
``x`` and ``y``, scales them both by some coefficents ``alpha`` and ``beta``
|
||||||
|
respectively, and then adds them together to get the result
|
||||||
|
``z = alpha * x + beta * y``. Well, you can very easily do that by just
|
||||||
|
writing out a function as follows:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
|
||||||
|
return alpha * x + beta * y
|
||||||
|
|
||||||
|
This function performs that operation while leaving the implementations and
|
||||||
|
differentiation to MLX.
|
||||||
|
|
||||||
|
However, you work with vector math libraries often and realize that the
|
||||||
|
``axpby`` routine defines the same operation ``Y = (alpha * X) + (beta * Y)``.
|
||||||
|
You would really like the part of your applications that does this operation
|
||||||
|
on the CPU to be very fast - so you decide that you want it to rely on the
|
||||||
|
``axpby`` routine provided by the Accelerate_ framework. Continuing to impose
|
||||||
|
our assumptions on to you, let's also assume that you want to learn how add
|
||||||
|
your own implementation for the gradients of your new operation while going
|
||||||
|
over the ins-and-outs of the MLX framework.
|
||||||
|
|
||||||
|
Well, what a coincidence! You are in the right place. Over the course of this
|
||||||
|
example, we will learn:
|
||||||
|
|
||||||
|
* The structure of the MLX library from the frontend API to the backend implementations.
|
||||||
|
* How to implement your own CPU backend that redirects to Accelerate_ when appropriate (and a fallback if needed).
|
||||||
|
* How to implement your own GPU implementation using metal.
|
||||||
|
* How to add your own ``vjp`` and ``jvp``.
|
||||||
|
* How to build your implementations, link them to MLX, and bind them to python.
|
||||||
|
|
||||||
|
Operations and Primitives
|
||||||
|
-------------------------
|
||||||
|
|
||||||
|
In one sentence, operations in MLX build the computation graph, and primitives
|
||||||
|
provide the rules for evaluation and transformations of said graph. Let's start
|
||||||
|
by discussing operations in more detail.
|
||||||
|
|
||||||
|
Operations
|
||||||
|
^^^^^^^^^^^
|
||||||
|
|
||||||
|
Operations are the frontend functions that operate on arrays. They are defined
|
||||||
|
in the C++ API (:ref:`cpp_ops`) and then we provide bindings to these
|
||||||
|
operations in the Python API (:ref:`ops`).
|
||||||
|
|
||||||
|
We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and ``y``,
|
||||||
|
and two scalars, ``alpha`` and ``beta``. This is how we would define it in the
|
||||||
|
C++ API:
|
||||||
|
|
||||||
|
.. code-block:: C++
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Scale and sum two vectors elementwise
|
||||||
|
* z = alpha * x + beta * y
|
||||||
|
*
|
||||||
|
* Follow numpy style broadcasting between x and y
|
||||||
|
* Inputs are upcasted to floats if needed
|
||||||
|
**/
|
||||||
|
array axpby(
|
||||||
|
const array& x, // Input array x
|
||||||
|
const array& y, // Input array y
|
||||||
|
const float alpha, // Scaling factor for x
|
||||||
|
const float beta, // Scaling factor for y
|
||||||
|
StreamOrDevice s = {} // Stream on which to schedule the operation
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
This operation itself can call other operations within it if needed. So, the
|
||||||
|
simplest way to go about implementing this operation would be do so in terms
|
||||||
|
of existing operations.
|
||||||
|
|
||||||
|
.. code-block:: C++
|
||||||
|
|
||||||
|
array axpby(
|
||||||
|
const array& x, // Input array x
|
||||||
|
const array& y, // Input array y
|
||||||
|
const float alpha, // Scaling factor for x
|
||||||
|
const float beta, // Scaling factor for y
|
||||||
|
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
|
||||||
|
) {
|
||||||
|
// Scale x and y on the provided stream
|
||||||
|
auto ax = multiply(array(alpha), x, s);
|
||||||
|
auto by = multiply(array(beta), y, s);
|
||||||
|
|
||||||
|
// Add and return
|
||||||
|
return add(ax, by, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
However, as we discussed earlier, this is not our goal. The operations themselves
|
||||||
|
do not contain the implementations that act on the data, nor do they contain the
|
||||||
|
rules of transformations. Rather, they are an easy to use interface that build
|
||||||
|
on top of the building blocks we call :class:`Primitive`.
|
||||||
|
|
||||||
|
Primitives
|
||||||
|
^^^^^^^^^^^
|
||||||
|
|
||||||
|
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
|
||||||
|
defines how to create an output given a set of input :class:`array` . Further,
|
||||||
|
a :class:`Primitive` is a class that contains rules on how it is evaluated
|
||||||
|
on the CPU or GPU, and how it acts under transformations such as ``vjp`` and
|
||||||
|
``jvp``. These words on their own can be a bit abstract, so lets take a step
|
||||||
|
back and go to our example to give ourselves a more concrete image.
|
||||||
|
|
||||||
|
.. code-block:: C++
|
||||||
|
|
||||||
|
class Axpby : public Primitive {
|
||||||
|
public:
|
||||||
|
explicit Axpby(Stream stream, float alpha, float beta)
|
||||||
|
: Primitive(stream), alpha_(alpha), beta_(beta){};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A primitive must know how to evaluate itself on the CPU/GPU
|
||||||
|
* for the given inputs and populate the output array.
|
||||||
|
*
|
||||||
|
* To avoid unecessary allocations, the evaluation function
|
||||||
|
* is responsible for allocating space for the array.
|
||||||
|
*/
|
||||||
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
|
||||||
|
/** The Jacobian-vector product. */
|
||||||
|
array jvp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& tangents,
|
||||||
|
const std::vector<int>& argnums) override;
|
||||||
|
|
||||||
|
/** The vector-Jacobian product. */
|
||||||
|
std::vector<array> vjp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const array& cotan,
|
||||||
|
const std::vector<int>& argnums) override;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The primitive must know how to vectorize itself accross
|
||||||
|
* the given axes. The output is a pair containing the array
|
||||||
|
* representing the vectorized computation and the axis which
|
||||||
|
* corresponds to the output vectorized dimension.
|
||||||
|
*/
|
||||||
|
std::pair<array, int> vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) override;
|
||||||
|
|
||||||
|
/** Print the primitive. */
|
||||||
|
void print(std::ostream& os) override {
|
||||||
|
os << "Axpby";
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Equivalence check **/
|
||||||
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
float alpha_;
|
||||||
|
float beta_;
|
||||||
|
|
||||||
|
/** Fall back implementation for evaluation on CPU */
|
||||||
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
|
};
|
||||||
|
|
||||||
|
The :class:`Axpby` class derives from the base :class:`Primitive` class and
|
||||||
|
follows the above demonstrated interface. :class:`Axpby` treats ``alpha`` and
|
||||||
|
``beta`` as parameters. It then provides implementations of how the array ``out``
|
||||||
|
is produced given ``inputs`` through :meth:`Axpby::eval_cpu` and
|
||||||
|
:meth:`Axpby::eval_gpu`. Further, it provides rules of transformations in
|
||||||
|
:meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and :meth:`Axpby::vmap`.
|
||||||
|
|
||||||
|
Using the Primitives
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Operations can use this :class:`Primitive` to add a new :class:`array` to
|
||||||
|
the computation graph. An :class:`array` can be constructed by providing its
|
||||||
|
data type, shape, the :class:`Primitive` that computes it, and the
|
||||||
|
:class:`array` inputs that are passed to the primitive.
|
||||||
|
|
||||||
|
Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
|
||||||
|
|
||||||
|
.. code-block:: C++
|
||||||
|
|
||||||
|
array axpby(
|
||||||
|
const array& x, // Input array x
|
||||||
|
const array& y, // Input array y
|
||||||
|
const float alpha, // Scaling factor for x
|
||||||
|
const float beta, // Scaling factor for y
|
||||||
|
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
|
||||||
|
) {
|
||||||
|
// Promote dtypes between x and y as needed
|
||||||
|
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
|
||||||
|
|
||||||
|
// Upcast to float32 for non-floating point inputs x and y
|
||||||
|
auto out_dtype = is_floating_point(promoted_dtype)
|
||||||
|
? promoted_dtype
|
||||||
|
: promote_types(promoted_dtype, float32);
|
||||||
|
|
||||||
|
// Cast x and y up to the determined dtype (on the same stream s)
|
||||||
|
auto x_casted = astype(x, out_dtype, s);
|
||||||
|
auto y_casted = astype(y, out_dtype, s);
|
||||||
|
|
||||||
|
// Broadcast the shapes of x and y (on the same stream s)
|
||||||
|
auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
|
||||||
|
auto out_shape = broadcasted_inputs[0].shape();
|
||||||
|
|
||||||
|
// Construct the array as the output of the Axpby primitive
|
||||||
|
// with the broadcasted and upcasted arrays as inputs
|
||||||
|
return array(
|
||||||
|
/* const std::vector<int>& shape = */ out_shape,
|
||||||
|
/* Dtype dtype = */ out_dtype,
|
||||||
|
/* std::unique_ptr<Primitive> primitive = */
|
||||||
|
std::make_unique<Axpby>(to_stream(s), alpha, beta),
|
||||||
|
/* const std::vector<array>& inputs = */ broadcasted_inputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
This operation now handles the following:
|
||||||
|
|
||||||
|
#. Upcast inputs and resolve the the output data type.
|
||||||
|
#. Broadcast the inputs and resolve the output shape.
|
||||||
|
#. Construct the primitive :class:`Axpby` using the given stream, ``alpha``, and ``beta``.
|
||||||
|
#. Construct the output :class:`array` using the primitive and the inputs.
|
||||||
|
|
||||||
|
Implementing the Primitive
|
||||||
|
--------------------------
|
||||||
|
|
||||||
|
No computation happens when we call the operation alone. In effect, the
|
||||||
|
operation only builds the computation graph. When we evaluate the output
|
||||||
|
array, MLX schedules the execution of the computation graph, and calls
|
||||||
|
:meth:`Axpby::eval_cpu` or :meth:`Axpby::eval_gpu` depending on the
|
||||||
|
stream/device specified by the user.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called,
|
||||||
|
no memory has been allocated for the output array. It falls on the implementation
|
||||||
|
of these functions to allocate memory as needed
|
||||||
|
|
||||||
|
Implementing the CPU Backend
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Let's start by trying to implement a naive and generic version of
|
||||||
|
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
|
||||||
|
:class:`Axpby` earlier called :meth:`Axpby::eval`.
|
||||||
|
|
||||||
|
Our naive method will go over each element of the output array, find the
|
||||||
|
corresponding input elements of ``x`` and ``y`` and perform the operation
|
||||||
|
pointwise. This is captured in the templated function :meth:`axpby_impl`.
|
||||||
|
|
||||||
|
.. code-block:: C++
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void axpby_impl(
|
||||||
|
const array& x,
|
||||||
|
const array& y,
|
||||||
|
array& out,
|
||||||
|
float alpha_,
|
||||||
|
float beta_) {
|
||||||
|
// We only allocate memory when we are ready to fill the output
|
||||||
|
// malloc_or_wait synchronously allocates available memory
|
||||||
|
// There may be a wait executed here if the allocation is requested
|
||||||
|
// under memory-pressured conditions
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
// Collect input and output data pointers
|
||||||
|
const T* x_ptr = x.data<T>();
|
||||||
|
const T* y_ptr = y.data<T>();
|
||||||
|
T* out_ptr = out.data<T>();
|
||||||
|
|
||||||
|
// Cast alpha and beta to the relevant types
|
||||||
|
T alpha = static_cast<T>(alpha_);
|
||||||
|
T beta = static_cast<T>(beta_);
|
||||||
|
|
||||||
|
// Do the elementwise operation for each output
|
||||||
|
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
|
||||||
|
// Map linear indices to offsets in x and y
|
||||||
|
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
|
||||||
|
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
|
||||||
|
|
||||||
|
// We allocate the output to be contiguous and regularly strided
|
||||||
|
// (defaults to row major) and hence it doesn't need additonal mapping
|
||||||
|
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Now, we would like our implementation to be able to do this pointwise operation
|
||||||
|
for all incoming floating point arrays. Accordingly, we add dispatches for
|
||||||
|
``float32``, ``float16``, ``bfloat16`` and ``complex64``. We throw an error
|
||||||
|
if we encounter an unexpected type.
|
||||||
|
|
||||||
|
.. code-block:: C++
|
||||||
|
|
||||||
|
/** Fall back implementation for evaluation on CPU */
|
||||||
|
void Axpby::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
// Check the inputs (registered in the op while contructing the out array)
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& x = inputs[0];
|
||||||
|
auto& y = inputs[1];
|
||||||
|
|
||||||
|
// Dispatch to the correct dtype
|
||||||
|
if (out.dtype() == float32) {
|
||||||
|
return axpby_impl<float>(x, y, out, alpha_, beta_);
|
||||||
|
} else if (out.dtype() == float16) {
|
||||||
|
return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
|
||||||
|
} else if (out.dtype() == bfloat16) {
|
||||||
|
return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
|
||||||
|
} else if (out.dtype() == complex64) {
|
||||||
|
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"Axpby is only supported for floating point types.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
We have a fallback implementation! Now, to do what we are really here to do.
|
||||||
|
Remember we wanted to use the ``axpby`` routine provided by the Accelerate_
|
||||||
|
framework? Well, there are 3 complications to keep in mind:
|
||||||
|
|
||||||
|
#. Accelerate does not provide implementations of ``axpby`` for half precision
|
||||||
|
floats. We can only direct to it for ``float32`` types
|
||||||
|
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all elements
|
||||||
|
have fixed strides between them. Possibly due to broadcasts and transposes,
|
||||||
|
we aren't guaranteed that the inputs fit this requirement. We can
|
||||||
|
only direct to Accelerate if both ``x`` and ``y`` are row contiguous or
|
||||||
|
column contiguous.
|
||||||
|
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` inplace.
|
||||||
|
MLX expects to write out the answer to a new array. We must copy the elements
|
||||||
|
of ``y`` into the output array and use that as an input to ``axpby``
|
||||||
|
|
||||||
|
Let's write out an implementation that uses Accelerate in the right conditions.
|
||||||
|
It must simply allocate data for the output, copy elements of ``y`` into it,
|
||||||
|
and then call the :meth:`catlas_saxpby` from accelerate.
|
||||||
|
|
||||||
|
.. code-block:: C++
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void axpby_impl_accelerate(
|
||||||
|
const array& x,
|
||||||
|
const array& y,
|
||||||
|
array& out,
|
||||||
|
float alpha_,
|
||||||
|
float beta_) {
|
||||||
|
// Accelerate library provides catlas_saxpby which does
|
||||||
|
// Y = (alpha * X) + (beta * Y) in place
|
||||||
|
// To use it, we first copy the data in y over to the output array
|
||||||
|
|
||||||
|
// This specialization requires both x and y be contiguous in the same mode
|
||||||
|
// i.e: corresponding linear indices in both point to corresponding elements
|
||||||
|
// The data in the output array is allocated to match the strides in y
|
||||||
|
// such that x, y, and out are contiguous in the same mode and
|
||||||
|
// no transposition is needed
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(y.data_size() * out.itemsize()),
|
||||||
|
y.data_size(),
|
||||||
|
y.strides(),
|
||||||
|
y.flags());
|
||||||
|
|
||||||
|
// We then copy over the elements using the contiguous vector specialization
|
||||||
|
copy_inplace(y, out, CopyType::Vector);
|
||||||
|
|
||||||
|
// Get x and y pointers for catlas_saxpby
|
||||||
|
const T* x_ptr = x.data<T>();
|
||||||
|
T* y_ptr = out.data<T>();
|
||||||
|
|
||||||
|
T alpha = static_cast<T>(alpha_);
|
||||||
|
T beta = static_cast<T>(beta_);
|
||||||
|
|
||||||
|
// Call the inplace accelerate operator
|
||||||
|
catlas_saxpby(
|
||||||
|
/* N = */ out.size(),
|
||||||
|
/* ALPHA = */ alpha,
|
||||||
|
/* X = */ x_ptr,
|
||||||
|
/* INCX = */ 1,
|
||||||
|
/* BETA = */ beta,
|
||||||
|
/* Y = */ y_ptr,
|
||||||
|
/* INCY = */ 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
Great! But what about the inputs that do not fit the criteria for accelerate?
|
||||||
|
Luckily, we can always just direct back to :meth:`Axpby::eval`.
|
||||||
|
|
||||||
|
With this in mind, lets finally implement our :meth:`Axpby::eval_cpu`.
|
||||||
|
|
||||||
|
.. code-block:: C++
|
||||||
|
|
||||||
|
/** Evaluate primitive on CPU using accelerate specializations */
|
||||||
|
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& x = inputs[0];
|
||||||
|
auto& y = inputs[1];
|
||||||
|
|
||||||
|
// Accelerate specialization for contiguous single precision float arrays
|
||||||
|
if (out.dtype() == float32 &&
|
||||||
|
((x.flags().row_contiguous && y.flags().row_contiguous) ||
|
||||||
|
(x.flags().col_contiguous && y.flags().col_contiguous))) {
|
||||||
|
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to common backend if specializations are not available
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
We have now hit a milestone! Just this much is enough to run the operation
|
||||||
|
:meth:`axpby` on a CPU stream!
|
||||||
|
|
||||||
|
If you do not plan on running the operation on the GPU or using transforms on
|
||||||
|
computation graphs that contain :class:`Axpby`, you can stop implementing the
|
||||||
|
primitive here and enjoy the speed-ups you get from the Accelerate library.
|
||||||
|
|
||||||
|
Implementing the GPU Backend
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Apple silicon devices address their GPUs using the Metal_ shading language, and
|
||||||
|
all GPU kernels in MLX are written using metal.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
Here are some helpful resources if you are new to metal!
|
||||||
|
|
||||||
|
* A walkthrough of the metal compute pipeline: `Metal Example`_
|
||||||
|
* Documentation for metal shading language: `Metal Specification`_
|
||||||
|
* Using metal from C++: `Metal-cpp`_
|
||||||
|
|
||||||
|
Let's keep the GPU algorithm simple. We will launch exactly as many threads
|
||||||
|
as there are elements in the output. Each thread will pick the element it needs
|
||||||
|
from ``x`` and ``y``, do the pointwise operation, and then update its assigned
|
||||||
|
element in the output.
|
||||||
|
|
||||||
|
.. code-block:: C++
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
[[kernel]] void axpby_general(
|
||||||
|
device const T* x [[buffer(0)]],
|
||||||
|
device const T* y [[buffer(1)]],
|
||||||
|
device T* out [[buffer(2)]],
|
||||||
|
constant const float& alpha [[buffer(3)]],
|
||||||
|
constant const float& beta [[buffer(4)]],
|
||||||
|
constant const int* shape [[buffer(5)]],
|
||||||
|
constant const size_t* x_strides [[buffer(6)]],
|
||||||
|
constant const size_t* y_strides [[buffer(7)]],
|
||||||
|
constant const int& ndim [[buffer(8)]],
|
||||||
|
uint index [[thread_position_in_grid]]) {
|
||||||
|
// Convert linear indices to offsets in array
|
||||||
|
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
|
||||||
|
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
|
||||||
|
|
||||||
|
// Do the operation and update the output
|
||||||
|
out[index] =
|
||||||
|
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
|
||||||
|
}
|
||||||
|
|
||||||
|
We then need to instantiate this template for all floating point types and give
|
||||||
|
each instantiation a unique host name so we can identify the right kernel for
|
||||||
|
each data type.
|
||||||
|
|
||||||
|
.. code-block:: C++
|
||||||
|
|
||||||
|
#define instantiate_axpby(type_name, type) \
|
||||||
|
template [[host_name("axpby_general_" #type_name)]] \
|
||||||
|
[[kernel]] void axpby_general<type>( \
|
||||||
|
device const type* x [[buffer(0)]], \
|
||||||
|
device const type* y [[buffer(1)]], \
|
||||||
|
device type* out [[buffer(2)]], \
|
||||||
|
constant const float& alpha [[buffer(3)]], \
|
||||||
|
constant const float& beta [[buffer(4)]], \
|
||||||
|
constant const int* shape [[buffer(5)]], \
|
||||||
|
constant const size_t* x_strides [[buffer(6)]], \
|
||||||
|
constant const size_t* y_strides [[buffer(7)]], \
|
||||||
|
constant const int& ndim [[buffer(8)]], \
|
||||||
|
uint index [[thread_position_in_grid]]);
|
||||||
|
|
||||||
|
instantiate_axpby(float32, float);
|
||||||
|
instantiate_axpby(float16, half);
|
||||||
|
instantiate_axpby(bflot16, bfloat16_t);
|
||||||
|
instantiate_axpby(complex64, complex64_t);
|
||||||
|
|
||||||
|
This kernel will be compiled into a metal library ``mlx_ext.metallib`` as we
|
||||||
|
will see later in :ref:`Building with CMake`. In the following example, we
|
||||||
|
assume that the library ``mlx_ext.metallib`` will always be co-located with
|
||||||
|
the executable/ shared-library calling the :meth:`register_library` function.
|
||||||
|
The :meth:`register_library` function takes the library's name and potential
|
||||||
|
path (or in this case, a function that can produce the path of the metal
|
||||||
|
library) and tries to load that library if it hasn't already been registered
|
||||||
|
by the relevant static :class:`mlx::core::metal::Device` object. This is why,
|
||||||
|
it is important to package your C++ library with the metal library. We will
|
||||||
|
go over this process in more detail later.
|
||||||
|
|
||||||
|
The logic to determine the kernel, set the inputs, resolve the grid dimensions
|
||||||
|
and dispatch it to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
|
||||||
|
below.
|
||||||
|
|
||||||
|
.. code-block:: C++
|
||||||
|
|
||||||
|
/** Evaluate primitive on GPU */
|
||||||
|
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
// Prepare inputs
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& x = inputs[0];
|
||||||
|
auto& y = inputs[1];
|
||||||
|
|
||||||
|
// Each primitive carries the stream it should execute on
|
||||||
|
// and each stream carries its device identifiers
|
||||||
|
auto& s = stream();
|
||||||
|
// We get the needed metal device using the stream
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
|
||||||
|
// Allocate output memory
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
// Resolve name of kernel (corresponds to axpby.metal)
|
||||||
|
std::ostringstream kname;
|
||||||
|
kname << "axpby_" << "general_" << type_to_name(out);
|
||||||
|
|
||||||
|
// Make sure the metal library is available and look for it
|
||||||
|
// in the same folder as this executable if needed
|
||||||
|
d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
|
||||||
|
|
||||||
|
// Make a kernel from this metal library
|
||||||
|
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||||
|
|
||||||
|
// Prepare to encode kernel
|
||||||
|
auto compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
|
// Kernel parameters are registered with buffer indices corresponding to
|
||||||
|
// those in the kernel decelaration at axpby.metal
|
||||||
|
int ndim = out.ndim();
|
||||||
|
size_t nelem = out.size();
|
||||||
|
|
||||||
|
// Encode input arrays to kernel
|
||||||
|
set_array_buffer(compute_encoder, x, 0);
|
||||||
|
set_array_buffer(compute_encoder, y, 1);
|
||||||
|
|
||||||
|
// Encode output arrays to kernel
|
||||||
|
set_array_buffer(compute_encoder, out, 2);
|
||||||
|
|
||||||
|
// Encode alpha and beta
|
||||||
|
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
|
||||||
|
compute_encoder->setBytes(&beta_, sizeof(float), 4);
|
||||||
|
|
||||||
|
// Encode shape, strides and ndim
|
||||||
|
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
|
||||||
|
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
|
||||||
|
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
|
||||||
|
compute_encoder->setBytes(&ndim, sizeof(int), 8);
|
||||||
|
|
||||||
|
// We launch 1 thread for each input and make sure that the number of
|
||||||
|
// threads in any given threadgroup is not higher than the max allowed
|
||||||
|
size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup());
|
||||||
|
|
||||||
|
// Fix the 3D size of each threadgroup (in terms of threads)
|
||||||
|
MTL::Size group_dims = MTL::Size(tgp_size, 1, 1);
|
||||||
|
|
||||||
|
// Fix the 3D size of the launch grid (in terms of threads)
|
||||||
|
MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
|
||||||
|
|
||||||
|
// Launch the grid with the given number of threads divded among
|
||||||
|
// the given threadgroups
|
||||||
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
We can now call the :meth:`axpby` operation on both the CPU and the GPU!
|
||||||
|
|
||||||
|
A few things to note about MLX and metal before moving on. MLX keeps track
|
||||||
|
of the active ``compute_encoder``. We rely on :meth:`d.get_command_encoder`
|
||||||
|
to give us the active metal compute command encoder instead of building a
|
||||||
|
new one and calling :meth:`compute_encoder->end_encoding` at the end.
|
||||||
|
MLX keeps adding kernels (compute pipelines) to the active command encoder
|
||||||
|
until some specified limit is hit or the compute encoder needs to be flushed
|
||||||
|
for synchronization. MLX also handles enqueuing and commiting the associated
|
||||||
|
command buffers as needed. We suggest taking a deeper dive into
|
||||||
|
:class:`metal::Device` if you would like to study this routine further.
|
||||||
|
|
||||||
|
Primitive Transforms
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Now that we have come this far, let's also learn how to add implementations to
|
||||||
|
transformations in a :class:`Primitive`. These transformations can be built on
|
||||||
|
top of our operations, including the one we just defined now. Which then gives
|
||||||
|
us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
|
||||||
|
|
||||||
|
.. code-block:: C++
|
||||||
|
|
||||||
|
/** The Jacobian-vector product. */
|
||||||
|
array Axpby::jvp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& tangents,
|
||||||
|
const std::vector<int>& argnums) {
|
||||||
|
// Forward mode diff that pushes along the tangents
|
||||||
|
// The jvp transform on the the primitive can built with ops
|
||||||
|
// that are scheduled on the same stream as the primtive
|
||||||
|
|
||||||
|
// If argnums = {0}, we only push along x in which case the
|
||||||
|
// jvp is just the tangent scaled by alpha
|
||||||
|
// Similarly, if argnums = {1}, the jvp is just the tangent
|
||||||
|
// scaled by beta
|
||||||
|
if (argnums.size() > 1) {
|
||||||
|
auto scale = argnums[0] == 0 ? alpha_ : beta_;
|
||||||
|
auto scale_arr = array(scale, tangents[0].dtype());
|
||||||
|
return multiply(scale_arr, tangents[0], stream());
|
||||||
|
}
|
||||||
|
// If, argnums = {0, 1}, we take contributions from both
|
||||||
|
// which gives us jvp = tangent_x * alpha + tangent_y * beta
|
||||||
|
else {
|
||||||
|
return axpby(tangents[0], tangents[1], alpha_, beta_, stream());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.. code-block:: C++
|
||||||
|
|
||||||
|
/** The vector-Jacobian product. */
|
||||||
|
std::vector<array> Axpby::vjp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const array& cotan,
|
||||||
|
const std::vector<int>& argnums) {
|
||||||
|
// Reverse mode diff
|
||||||
|
std::vector<array> vjps;
|
||||||
|
for (auto arg : argnums) {
|
||||||
|
auto scale = arg == 0 ? alpha_ : beta_;
|
||||||
|
auto scale_arr = array(scale, cotan.dtype());
|
||||||
|
vjps.push_back(multiply(scale_arr, cotan, stream()));
|
||||||
|
}
|
||||||
|
return vjps;
|
||||||
|
}
|
||||||
|
|
||||||
|
Finally, you need not have a transformation fully defined to start using your
|
||||||
|
own :class:`Primitive`.
|
||||||
|
|
||||||
|
.. code-block:: C++
|
||||||
|
|
||||||
|
/** Vectorize primitve along given axis */
|
||||||
|
std::pair<array, int> Axpby::vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
throw std::runtime_error("Axpby has no vmap implementation.");
|
||||||
|
}
|
||||||
|
|
||||||
|
Building and Binding
|
||||||
|
--------------------
|
||||||
|
|
||||||
|
Let's look at the overall directory structure first.
|
||||||
|
|
||||||
|
| extensions
|
||||||
|
| ├── axpby
|
||||||
|
| │ ├── axpby.cpp
|
||||||
|
| │ ├── axpby.h
|
||||||
|
| │ └── axpby.metal
|
||||||
|
| ├── mlx_sample_extensions
|
||||||
|
| │ └── __init__.py
|
||||||
|
| ├── bindings.cpp
|
||||||
|
| ├── CMakeLists.txt
|
||||||
|
| └── setup.py
|
||||||
|
|
||||||
|
* ``extensions/axpby/`` defines the C++ extension library
|
||||||
|
* ``extensions/mlx_sample_extensions`` sets out the strucutre for the
|
||||||
|
associated python package
|
||||||
|
* ``extensions/bindings.cpp`` provides python bindings for our operation
|
||||||
|
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
|
||||||
|
python bindings
|
||||||
|
* ``extensions/setup.py`` holds the ``setuptools`` rules to build and install
|
||||||
|
the python package
|
||||||
|
|
||||||
|
Binding to Python
|
||||||
|
^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
We use PyBind11_ to build a Python API for the C++ library. Since bindings
|
||||||
|
for all needed components such as `mlx.core.array`, `mlx.core.stream`, etc.
|
||||||
|
are already provided, adding our :meth:`axpby` becomes very simple!
|
||||||
|
|
||||||
|
.. code-block:: C++
|
||||||
|
|
||||||
|
PYBIND11_MODULE(mlx_sample_extensions, m) {
|
||||||
|
m.doc() = "Sample C++ and metal extensions for MLX";
|
||||||
|
|
||||||
|
m.def(
|
||||||
|
"axpby",
|
||||||
|
&axpby,
|
||||||
|
"x"_a,
|
||||||
|
"y"_a,
|
||||||
|
py::pos_only(),
|
||||||
|
"alpha"_a,
|
||||||
|
"beta"_a,
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = py::none(),
|
||||||
|
R"pbdoc(
|
||||||
|
Scale and sum two vectors elementwise
|
||||||
|
``z = alpha * x + beta * y``
|
||||||
|
|
||||||
|
Follows numpy style broadcasting between ``x`` and ``y``
|
||||||
|
Inputs are upcasted to floats if needed
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (array): Input array.
|
||||||
|
y (array): Input array.
|
||||||
|
alpha (float): Scaling factor for ``x``.
|
||||||
|
beta (float): Scaling factor for ``y``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: ``alpha * x + beta * y``
|
||||||
|
)pbdoc");
|
||||||
|
}
|
||||||
|
|
||||||
|
Most of the complexity in the above example comes from additional bells and
|
||||||
|
whistles such as the literal names and doc-strings.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
:mod:`mlx.core` needs to be imported before importing
|
||||||
|
:mod:`mlx_sample_extensions` as defined by the pybind11 module above to
|
||||||
|
ensure that the casters for :mod:`mlx.core` components like
|
||||||
|
:class:`mlx.core.array` are available.
|
||||||
|
|
||||||
|
.. _Building with CMake:
|
||||||
|
|
||||||
|
Building with CMake
|
||||||
|
^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Building the C++ extension library itself is simple, it only requires that you
|
||||||
|
``find_package(MLX CONFIG)`` and then link it to your library.
|
||||||
|
|
||||||
|
.. code-block:: cmake
|
||||||
|
|
||||||
|
# Add library
|
||||||
|
add_library(mlx_ext)
|
||||||
|
|
||||||
|
# Add sources
|
||||||
|
target_sources(
|
||||||
|
mlx_ext
|
||||||
|
PUBLIC
|
||||||
|
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add include headers
|
||||||
|
target_include_directories(
|
||||||
|
mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Link to mlx
|
||||||
|
target_link_libraries(mlx_ext PUBLIC mlx)
|
||||||
|
|
||||||
|
We also need to build the attached metal library. For convenience, we provide a
|
||||||
|
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
|
||||||
|
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
|
||||||
|
automatically imported with MLX package).
|
||||||
|
|
||||||
|
Here is what that looks like in practice!
|
||||||
|
|
||||||
|
.. code-block:: cmake
|
||||||
|
|
||||||
|
# Build metallib
|
||||||
|
if(MLX_BUILD_METAL)
|
||||||
|
|
||||||
|
mlx_build_metallib(
|
||||||
|
TARGET mlx_ext_metallib
|
||||||
|
TITLE mlx_ext
|
||||||
|
SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
|
||||||
|
INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
|
||||||
|
OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
|
||||||
|
)
|
||||||
|
|
||||||
|
add_dependencies(
|
||||||
|
mlx_ext
|
||||||
|
mlx_ext_metallib
|
||||||
|
)
|
||||||
|
|
||||||
|
endif()
|
||||||
|
|
||||||
|
Finally, we build the Pybind11_ bindings
|
||||||
|
|
||||||
|
.. code-block:: cmake
|
||||||
|
|
||||||
|
pybind11_add_module(
|
||||||
|
mlx_sample_extensions
|
||||||
|
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
|
||||||
|
)
|
||||||
|
target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
|
||||||
|
|
||||||
|
if(BUILD_SHARED_LIBS)
|
||||||
|
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
Building with ``setuptools``
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Once we have set out the CMake build rules as described above, we can use the
|
||||||
|
build utilities defined in :mod:`mlx.extension` for a simple build process.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from mlx import extension
|
||||||
|
from setuptools import setup
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
setup(
|
||||||
|
name="mlx_sample_extensions",
|
||||||
|
version="0.0.0",
|
||||||
|
description="Sample C++ and Metal extensions for MLX primitives.",
|
||||||
|
ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
|
||||||
|
cmdclass={"build_ext": extension.CMakeBuild},
|
||||||
|
packages = ["mlx_sample_extensions"],
|
||||||
|
package_dir = {"": "mlx_sample_extensions"},
|
||||||
|
package_data = {"mlx_sample_extensions" : ["*.so", "*.dylib", "*.metallib"]},
|
||||||
|
zip_safe=False,
|
||||||
|
python_requires=">=3.7",
|
||||||
|
)
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
We treat ``extensions/mlx_sample_extensions`` as the package directory
|
||||||
|
even though it only contains a ``__init__.py`` to ensure the following:
|
||||||
|
|
||||||
|
* :mod:`mlx.core` is always imported before importing :mod:`mlx_sample_extensions`
|
||||||
|
* The C++ extension library and the metal library are co-located with the python
|
||||||
|
bindings and copied together if the package is installed
|
||||||
|
|
||||||
|
You can build inplace for development using
|
||||||
|
``python setup.py build_ext -j8 --inplace`` (in ``extensions/``)
|
||||||
|
|
||||||
|
This will result in a directory structure as follows:
|
||||||
|
|
||||||
|
| extensions
|
||||||
|
| ├── mlx_sample_extensions
|
||||||
|
| │ ├── __init__.py
|
||||||
|
| │ ├── libmlx_ext.dylib # C++ extension library
|
||||||
|
| │ ├── mlx_ext.metallib # Metal library
|
||||||
|
| │ └── mlx_sample_extensions.cpython-3x-darwin.so # Python Binding
|
||||||
|
| ...
|
||||||
|
|
||||||
|
When you try to install using the command ``python -m pip install .``
|
||||||
|
(in ``extensions/``), the package will be installed with the same strucutre as
|
||||||
|
``extensions/mlx_sample_extensions`` and the C++ and metal library will be
|
||||||
|
copied along with the python binding since they are specified as ``package_data``.
|
||||||
|
|
||||||
|
Usage
|
||||||
|
-----
|
||||||
|
|
||||||
|
After installing the extension as described above, you should be able to simply
|
||||||
|
import the python package and play with it as you would any other MLX operation!
|
||||||
|
|
||||||
|
Let's looks at a simple script and it's results!
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx_sample_extensions import axpby
|
||||||
|
|
||||||
|
a = mx.ones((3, 4))
|
||||||
|
b = mx.ones((3, 4))
|
||||||
|
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
|
||||||
|
|
||||||
|
print(f"c shape: {c.shape}")
|
||||||
|
print(f"c dtype: {c.dtype}")
|
||||||
|
print(f"c correctness: {mx.all(c == 6.0).item()}")
|
||||||
|
|
||||||
|
Output:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
c shape: [3, 4]
|
||||||
|
c dtype: float32
|
||||||
|
c correctness: True
|
||||||
|
|
||||||
|
Results
|
||||||
|
^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Let's run a quick benchmark and see how our new ``axpby`` operation compares
|
||||||
|
with the naive :meth:`simple_axpby` we defined at first on the CPU.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx_sample_extensions import axpby
|
||||||
|
import time
|
||||||
|
|
||||||
|
mx.set_default_device(mx.cpu)
|
||||||
|
|
||||||
|
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
|
||||||
|
return alpha * x + beta * y
|
||||||
|
|
||||||
|
M = 256
|
||||||
|
N = 512
|
||||||
|
|
||||||
|
x = mx.random.normal((M, N))
|
||||||
|
y = mx.random.normal((M, N))
|
||||||
|
alpha = 4.0
|
||||||
|
beta = 2.0
|
||||||
|
|
||||||
|
mx.eval((x, y))
|
||||||
|
|
||||||
|
def bench(f):
|
||||||
|
# Warm up
|
||||||
|
for i in range(100):
|
||||||
|
z = f(x, y, alpha, beta)
|
||||||
|
mx.eval(z)
|
||||||
|
|
||||||
|
# Timed run
|
||||||
|
s = time.time()
|
||||||
|
for i in range(5000):
|
||||||
|
z = f(x, y, alpha, beta)
|
||||||
|
mx.eval(z)
|
||||||
|
e = time.time()
|
||||||
|
return e - s
|
||||||
|
|
||||||
|
simple_time = bench(simple_axpby)
|
||||||
|
custom_time = bench(axpby)
|
||||||
|
|
||||||
|
print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
|
||||||
|
|
||||||
|
Results:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
Simple axpby: 0.114 s | Custom axpby: 0.109 s
|
||||||
|
|
||||||
|
We see some modest improvements right away!
|
||||||
|
|
||||||
|
This operation is now good to be used to build other operations,
|
||||||
|
in :class:`mlx.nn.Module` calls, and also as a part of graph
|
||||||
|
transformations such as :meth:`grad` and :meth:`simplify`!
|
||||||
|
|
||||||
|
Scripts
|
||||||
|
-------
|
||||||
|
|
||||||
|
.. admonition:: Download the code
|
||||||
|
|
||||||
|
The full example code is available in `mlx-examples <code>`_.
|
||||||
|
|
||||||
|
.. code: `TODO_LINK/extensions`_
|
||||||
|
|
||||||
|
.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
|
||||||
|
.. _Metal: https://developer.apple.com/documentation/metal?language=objc
|
||||||
|
.. _Metal-cpp: https://developer.apple.com/metal/cpp/
|
||||||
|
.. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
||||||
|
.. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc
|
||||||
|
.. _PyBind11: https://pybind11.readthedocs.io/en/stable/
|
77
docs/src/examples/linear_regression.rst
Normal file
77
docs/src/examples/linear_regression.rst
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
.. _linear_regression:
|
||||||
|
|
||||||
|
Linear Regression
|
||||||
|
-----------------
|
||||||
|
|
||||||
|
Let's implement a basic linear regression model as a starting point to
|
||||||
|
learn MLX. First import the core package and setup some problem metadata:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
num_features = 100
|
||||||
|
num_examples = 1_000
|
||||||
|
num_iters = 10_000 # iterations of SGD
|
||||||
|
lr = 0.01 # learning rate for SGD
|
||||||
|
|
||||||
|
|
||||||
|
We'll generate a synthetic dataset by:
|
||||||
|
|
||||||
|
1. Sampling the design matrix ``X``.
|
||||||
|
2. Sampling a ground truth parameter vector ``w_star``.
|
||||||
|
3. Compute the dependent values ``y`` by adding Gaussian noise to ``X @ w_star``.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
# True parameters
|
||||||
|
w_star = mx.random.normal((num_features,))
|
||||||
|
|
||||||
|
# Input examples (design matrix)
|
||||||
|
X = mx.random.normal((num_examples, num_features))
|
||||||
|
|
||||||
|
# Noisy labels
|
||||||
|
eps = 1e-2 * mx.random.normal((num_examples,))
|
||||||
|
y = X @ w_star + eps
|
||||||
|
|
||||||
|
|
||||||
|
We will use SGD to find the optimal weights. To start, define the squared loss
|
||||||
|
and get the gradient function of the loss with respect to the parameters.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
def loss_fn(w):
|
||||||
|
return 0.5 * mx.mean(mx.square(X @ w - y))
|
||||||
|
|
||||||
|
grad_fn = mx.grad(loss_fn)
|
||||||
|
|
||||||
|
Start the optimization by initializing the parameters ``w`` randomly. Then
|
||||||
|
repeatedly update the parameters for ``num_iters`` iterations.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
w = 1e-2 * mx.random.normal((num_features,))
|
||||||
|
|
||||||
|
for _ in range(num_iters):
|
||||||
|
grad = grad_fn(w)
|
||||||
|
w = w - lr * grad
|
||||||
|
mx.eval(w)
|
||||||
|
|
||||||
|
Finally, compute the loss of the learned parameters and verify that they are
|
||||||
|
close to the ground truth parameters.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
loss = loss_fn(w)
|
||||||
|
error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}, "
|
||||||
|
)
|
||||||
|
# Should print something close to: Loss 0.00005, |w-w*| = 0.00364
|
||||||
|
|
||||||
|
Complete `linear regression
|
||||||
|
<https://github.com/ml-explore/mlx/tree/main/examples/python/linear_regression.py>`_
|
||||||
|
and `logistic regression
|
||||||
|
<https://github.com/ml-explore/mlx/tree/main/examples/python/logistic_regression.py>`_
|
||||||
|
examples are available in the MLX GitHub repo.
|
52
docs/src/python/data_types.rst
Normal file
52
docs/src/python/data_types.rst
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
.. _data_types:
|
||||||
|
|
||||||
|
:orphan:
|
||||||
|
|
||||||
|
Data Types
|
||||||
|
==========
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
|
The default floating point type is ``float32`` and the default integer type is
|
||||||
|
``int32``. The table below shows supported values for :obj:`Dtype`.
|
||||||
|
|
||||||
|
.. list-table:: Supported Data Types
|
||||||
|
:widths: 5 3 20
|
||||||
|
:header-rows: 1
|
||||||
|
|
||||||
|
* - Type
|
||||||
|
- Bytes
|
||||||
|
- Description
|
||||||
|
* - ``bool_``
|
||||||
|
- 1
|
||||||
|
- Boolean (``True``, ``False``) data type
|
||||||
|
* - ``uint8``
|
||||||
|
- 1
|
||||||
|
- 8-bit unsigned integer
|
||||||
|
* - ``uint16``
|
||||||
|
- 2
|
||||||
|
- 16-bit unsigned integer
|
||||||
|
* - ``uint32``
|
||||||
|
- 4
|
||||||
|
- 32-bit unsigned integer
|
||||||
|
* - ``uint32``
|
||||||
|
- 8
|
||||||
|
- 32-bit unsigned integer
|
||||||
|
* - ``int8``
|
||||||
|
- 1
|
||||||
|
- 8-bit signed integer
|
||||||
|
* - ``int16``
|
||||||
|
- 2
|
||||||
|
- 16-bit signed integer
|
||||||
|
* - ``int32``
|
||||||
|
- 4
|
||||||
|
- 32-bit signed integer
|
||||||
|
* - ``int64``
|
||||||
|
- 8
|
||||||
|
- 64-bit signed integer
|
||||||
|
* - ``float16``
|
||||||
|
- 2
|
||||||
|
- 16-bit float, only available with `ARM C language extensions <https://developer.arm.com/documentation/101028/0012/3--C-language-extensions?lang=en>`_
|
||||||
|
* - ``float32``
|
||||||
|
- 4
|
||||||
|
- 32-bit float
|
17
docs/src/python/devices_and_streams.rst
Normal file
17
docs/src/python/devices_and_streams.rst
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
.. _devices_and_streams:
|
||||||
|
|
||||||
|
Devices and Streams
|
||||||
|
===================
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
Device
|
||||||
|
default_device
|
||||||
|
set_default_device
|
||||||
|
Stream
|
||||||
|
default_stream
|
||||||
|
new_stream
|
||||||
|
set_default_stream
|
16
docs/src/python/transforms.rst
Normal file
16
docs/src/python/transforms.rst
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
.. _transforms:
|
||||||
|
|
||||||
|
Transforms
|
||||||
|
==========
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
eval
|
||||||
|
grad
|
||||||
|
value_and_grad
|
||||||
|
jvp
|
||||||
|
vjp
|
||||||
|
vmap
|
21
docs/src/python/tree_utils.rst
Normal file
21
docs/src/python/tree_utils.rst
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
.. _utils:
|
||||||
|
|
||||||
|
Tree Utils
|
||||||
|
==========
|
||||||
|
|
||||||
|
In MLX we consider a python tree to be an arbitrarily nested collection of
|
||||||
|
dictionaries, lists and tuples without cycles. Functions in this module that
|
||||||
|
return python trees will be using the default python ``dict``, ``list`` and
|
||||||
|
``tuple`` but they can usually process objects that inherit from any of these.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
Dictionaries should have keys that are valid python identifiers.
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.utils
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
tree_flatten
|
||||||
|
tree_unflatten
|
||||||
|
tree_map
|
10
examples/cpp/CMakeLists.txt
Normal file
10
examples/cpp/CMakeLists.txt
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
function(build_example SRCFILE)
|
||||||
|
get_filename_component(src_name ${SRCFILE} NAME_WE)
|
||||||
|
set(target "${src_name}")
|
||||||
|
add_executable(${target} ${SRCFILE})
|
||||||
|
target_link_libraries(${target} PRIVATE mlx)
|
||||||
|
endfunction(build_example)
|
||||||
|
|
||||||
|
build_example(tutorial.cpp)
|
||||||
|
build_example(linear_regression.cpp)
|
||||||
|
build_example(logistic_regression.cpp)
|
52
examples/cpp/linear_regression.cpp
Normal file
52
examples/cpp/linear_regression.cpp
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
#include <chrono>
|
||||||
|
#include <cmath>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "mlx/mlx.h"
|
||||||
|
#include "timer.h"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An example of linear regression with MLX.
|
||||||
|
*/
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
int num_features = 100;
|
||||||
|
int num_examples = 1'000;
|
||||||
|
int num_iters = 10'000;
|
||||||
|
float learning_rate = 0.01;
|
||||||
|
|
||||||
|
// True parameters
|
||||||
|
auto w_star = random::normal({num_features});
|
||||||
|
|
||||||
|
// The input examples (design matrix)
|
||||||
|
auto X = random::normal({num_examples, num_features});
|
||||||
|
|
||||||
|
// Noisy labels
|
||||||
|
auto eps = 1e-2 * random::normal({num_examples});
|
||||||
|
auto y = matmul(X, w_star) + eps;
|
||||||
|
|
||||||
|
// Initialize random parameters
|
||||||
|
array w = 1e-2 * random::normal({num_features});
|
||||||
|
|
||||||
|
auto loss_fn = [&](array w) {
|
||||||
|
auto yhat = matmul(X, w);
|
||||||
|
return (0.5f / num_examples) * sum(square(yhat - y));
|
||||||
|
};
|
||||||
|
|
||||||
|
auto grad_fn = grad(loss_fn);
|
||||||
|
|
||||||
|
auto tic = timer::time();
|
||||||
|
for (int it = 0; it < num_iters; ++it) {
|
||||||
|
auto grad = grad_fn(w);
|
||||||
|
w = w - learning_rate * grad;
|
||||||
|
eval(w);
|
||||||
|
}
|
||||||
|
auto toc = timer::time();
|
||||||
|
|
||||||
|
auto loss = loss_fn(w);
|
||||||
|
auto error_norm = std::sqrt(sum(square(w - w_star)).item<float>());
|
||||||
|
auto throughput = num_iters / timer::seconds(toc - tic);
|
||||||
|
std::cout << "Loss " << loss << ", |w - w*| = " << error_norm
|
||||||
|
<< ", Throughput " << throughput << " (it/s)." << std::endl;
|
||||||
|
}
|
3581
mlx/3rdparty/pocketfft.h
vendored
Normal file
3581
mlx/3rdparty/pocketfft.h
vendored
Normal file
File diff suppressed because it is too large
Load Diff
64
mlx/allocator.h
Normal file
64
mlx/allocator.h
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
|
|
||||||
|
namespace mlx::core::allocator {
|
||||||
|
|
||||||
|
// Simple wrapper around buffer pointers
|
||||||
|
// WARNING: Only Buffer objects constructed from and those that wrap
|
||||||
|
// raw pointers from mlx::allocator are supported.
|
||||||
|
class Buffer {
|
||||||
|
private:
|
||||||
|
void* ptr_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Buffer(void* ptr) : ptr_(ptr){};
|
||||||
|
|
||||||
|
// Get the raw data pointer from the buffer
|
||||||
|
void* raw_ptr();
|
||||||
|
|
||||||
|
// Get the buffer pointer from the buffer
|
||||||
|
const void* ptr() const {
|
||||||
|
return ptr_;
|
||||||
|
};
|
||||||
|
void* ptr() {
|
||||||
|
return ptr_;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
Buffer malloc(size_t size);
|
||||||
|
|
||||||
|
void free(Buffer buffer);
|
||||||
|
|
||||||
|
// Wait for running tasks to finish and free up memory
|
||||||
|
// if allocation fails
|
||||||
|
Buffer malloc_or_wait(size_t size);
|
||||||
|
|
||||||
|
class Allocator {
|
||||||
|
/** Abstract base clase for a memory allocator. */
|
||||||
|
public:
|
||||||
|
virtual Buffer malloc(size_t size) = 0;
|
||||||
|
virtual void free(Buffer buffer) = 0;
|
||||||
|
|
||||||
|
Allocator() = default;
|
||||||
|
Allocator(const Allocator& other) = delete;
|
||||||
|
Allocator(Allocator&& other) = delete;
|
||||||
|
Allocator& operator=(const Allocator& other) = delete;
|
||||||
|
Allocator& operator=(Allocator&& other) = delete;
|
||||||
|
virtual ~Allocator() = default;
|
||||||
|
};
|
||||||
|
|
||||||
|
Allocator& allocator();
|
||||||
|
|
||||||
|
class CommonAllocator : public Allocator {
|
||||||
|
/** A general CPU allocator. */
|
||||||
|
public:
|
||||||
|
virtual Buffer malloc(size_t size) override;
|
||||||
|
virtual void free(Buffer buffer) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
CommonAllocator() = default;
|
||||||
|
friend Allocator& allocator();
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlx::core::allocator
|
18
mlx/backend/accelerate/conv.cpp
Normal file
18
mlx/backend/accelerate/conv.cpp
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include <simd/vector.h>
|
||||||
|
#include <vecLib/vDSP.h>
|
||||||
|
|
||||||
|
#include "mlx/backend/common/copy.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
eval(inputs, out);
|
||||||
|
|
||||||
|
// TODO: Add accelerate based optimizations for CPU conv
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
26
mlx/backend/accelerate/utils.h
Normal file
26
mlx/backend/accelerate/utils.h
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <vecLib/BNNS/bnns.h>
|
||||||
|
#include "mlx/dtype.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
BNNSDataType to_bnns_dtype(Dtype mlx_dtype) {
|
||||||
|
uint32_t size_bits = size_of(mlx_dtype) * 8;
|
||||||
|
switch (kindof(mlx_dtype)) {
|
||||||
|
case Dtype::Kind::b:
|
||||||
|
return BNNSDataTypeBoolean;
|
||||||
|
case Dtype::Kind::u:
|
||||||
|
return BNNSDataType(BNNSDataTypeUIntBit | size_bits);
|
||||||
|
case Dtype::Kind::i:
|
||||||
|
return BNNSDataType(BNNSDataTypeIntBit | size_bits);
|
||||||
|
case Dtype::Kind::f:
|
||||||
|
return BNNSDataType(BNNSDataTypeFloatBit | size_bits);
|
||||||
|
case Dtype::Kind::V:
|
||||||
|
return BNNSDataTypeBFloat16;
|
||||||
|
case Dtype::Kind::c:
|
||||||
|
throw std::invalid_argument("BNNS does not support complex types");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
27
mlx/backend/common/copy.h
Normal file
27
mlx/backend/common/copy.h
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
enum class CopyType {
|
||||||
|
// Copy a raw scalar input into the full contiguous output
|
||||||
|
Scalar,
|
||||||
|
|
||||||
|
// Copy the raw input buffer contiguously into a raw output buffer of the same
|
||||||
|
// size
|
||||||
|
Vector,
|
||||||
|
|
||||||
|
// Copy the full virtual input to the full contiguous output
|
||||||
|
General,
|
||||||
|
|
||||||
|
// Copy the full virtual input to the full virtual output. We assume the
|
||||||
|
// input and output have the same shape.
|
||||||
|
GeneralGeneral
|
||||||
|
};
|
||||||
|
|
||||||
|
void copy(const array& src, array& dst, CopyType ctype);
|
||||||
|
void copy_inplace(const array& src, array& dst, CopyType ctype);
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
394
mlx/backend/common/sort.cpp
Normal file
394
mlx/backend/common/sort.cpp
Normal file
@ -0,0 +1,394 @@
|
|||||||
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
|
#include <cmath>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
#include "mlx/backend/common/copy.h"
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
|
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename T, typename IdxT = int32_t>
|
||||||
|
struct StridedIterator {
|
||||||
|
using iterator_category = std::random_access_iterator_tag;
|
||||||
|
using difference_type = IdxT;
|
||||||
|
using value_type = T;
|
||||||
|
using reference = value_type&;
|
||||||
|
using pointer = value_type*;
|
||||||
|
|
||||||
|
// Constructors
|
||||||
|
StridedIterator() = default;
|
||||||
|
|
||||||
|
explicit StridedIterator(T* ptr, size_t stride, difference_type offset = 0)
|
||||||
|
: ptr_(ptr + offset * stride), stride_(stride) {}
|
||||||
|
|
||||||
|
explicit StridedIterator(array& arr, int axis, difference_type offset = 0)
|
||||||
|
: StridedIterator(arr.data<T>(), arr.strides()[axis], offset) {}
|
||||||
|
|
||||||
|
// Accessors
|
||||||
|
reference operator*() const {
|
||||||
|
return ptr_[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
reference operator[](difference_type idx) const {
|
||||||
|
return ptr_[idx * stride_];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Comparisons
|
||||||
|
bool operator==(const StridedIterator& other) const {
|
||||||
|
return ptr_ == other.ptr_ && stride_ == other.stride_;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator!=(const StridedIterator& other) const {
|
||||||
|
return ptr_ != other.ptr_;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator<(const StridedIterator& other) const {
|
||||||
|
return ptr_ < other.ptr_;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator>(const StridedIterator& other) const {
|
||||||
|
return ptr_ > other.ptr_;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator<=(const StridedIterator& other) const {
|
||||||
|
return ptr_ <= other.ptr_;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator>=(const StridedIterator& other) const {
|
||||||
|
return ptr_ >= other.ptr_;
|
||||||
|
}
|
||||||
|
|
||||||
|
difference_type operator-(const StridedIterator& other) const {
|
||||||
|
return (ptr_ - other.ptr_) / stride_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Moving
|
||||||
|
StridedIterator& operator++() {
|
||||||
|
ptr_ += stride_;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
StridedIterator& operator--() {
|
||||||
|
ptr_ -= stride_;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
StridedIterator& operator+=(difference_type diff) {
|
||||||
|
ptr_ += diff * stride_;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
StridedIterator& operator-=(difference_type diff) {
|
||||||
|
ptr_ -= diff * stride_;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
StridedIterator operator+(difference_type diff) {
|
||||||
|
return StridedIterator(ptr_, stride_, diff);
|
||||||
|
}
|
||||||
|
|
||||||
|
StridedIterator operator-(difference_type diff) {
|
||||||
|
return StridedIterator(ptr_, stride_, -diff);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
size_t stride_;
|
||||||
|
T* ptr_;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename IdxT = uint32_t>
|
||||||
|
void sort(const array& in, array& out, int axis) {
|
||||||
|
// Copy input to output
|
||||||
|
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||||
|
copy(in, out, ctype);
|
||||||
|
|
||||||
|
// Get axis, shape and stride info
|
||||||
|
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||||
|
size_t n_rows = in.size() / in.shape(axis);
|
||||||
|
|
||||||
|
auto remaining_shape = in.shape();
|
||||||
|
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||||
|
|
||||||
|
auto remaining_strides = in.strides();
|
||||||
|
remaining_strides.erase(remaining_strides.begin() + axis);
|
||||||
|
|
||||||
|
size_t axis_stride = in.strides()[axis];
|
||||||
|
int axis_size = in.shape(axis);
|
||||||
|
|
||||||
|
// Perform sorting in place
|
||||||
|
for (int i = 0; i < n_rows; i++) {
|
||||||
|
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
|
||||||
|
T* data_ptr = out.data<T>() + loc;
|
||||||
|
|
||||||
|
StridedIterator st(data_ptr, axis_stride, 0);
|
||||||
|
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
||||||
|
|
||||||
|
std::stable_sort(st, ed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename IdxT = uint32_t>
|
||||||
|
void argsort(const array& in, array& out, int axis) {
|
||||||
|
// Allocate output
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
// Get axis, shape and stride info
|
||||||
|
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||||
|
size_t n_rows = in.size() / in.shape(axis);
|
||||||
|
|
||||||
|
auto remaining_shape = in.shape();
|
||||||
|
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||||
|
|
||||||
|
auto remaining_strides = in.strides();
|
||||||
|
remaining_strides.erase(remaining_strides.begin() + axis);
|
||||||
|
|
||||||
|
size_t axis_stride = in.strides()[axis];
|
||||||
|
int axis_size = in.shape(axis);
|
||||||
|
|
||||||
|
// Perform sorting
|
||||||
|
for (int i = 0; i < n_rows; i++) {
|
||||||
|
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
|
||||||
|
const T* data_ptr = in.data<T>() + loc;
|
||||||
|
IdxT* idx_ptr = out.data<IdxT>() + loc;
|
||||||
|
|
||||||
|
StridedIterator st_(idx_ptr, axis_stride, 0);
|
||||||
|
StridedIterator ed_(idx_ptr, axis_stride, axis_size);
|
||||||
|
|
||||||
|
// Initialize with iota
|
||||||
|
std::iota(st_, ed_, IdxT(0));
|
||||||
|
|
||||||
|
// Sort according to vals
|
||||||
|
StridedIterator st(idx_ptr, axis_stride, 0);
|
||||||
|
StridedIterator ed(idx_ptr, axis_stride, axis_size);
|
||||||
|
|
||||||
|
std::stable_sort(st, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {
|
||||||
|
auto v1 = data_ptr[a * axis_stride];
|
||||||
|
auto v2 = data_ptr[b * axis_stride];
|
||||||
|
return v1 < v2 || (v1 == v2 && a < b);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename IdxT = uint32_t>
|
||||||
|
void partition(const array& in, array& out, int axis, int kth) {
|
||||||
|
// Copy input to output
|
||||||
|
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||||
|
copy(in, out, ctype);
|
||||||
|
|
||||||
|
// Get axis, shape and stride info
|
||||||
|
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||||
|
size_t n_rows = in.size() / in.shape(axis);
|
||||||
|
|
||||||
|
auto remaining_shape = in.shape();
|
||||||
|
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||||
|
|
||||||
|
auto remaining_strides = in.strides();
|
||||||
|
remaining_strides.erase(remaining_strides.begin() + axis);
|
||||||
|
|
||||||
|
size_t axis_stride = in.strides()[axis];
|
||||||
|
int axis_size = in.shape(axis);
|
||||||
|
|
||||||
|
kth = kth < 0 ? kth + axis_size : kth;
|
||||||
|
|
||||||
|
// Perform partition in place
|
||||||
|
for (int i = 0; i < n_rows; i++) {
|
||||||
|
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
|
||||||
|
T* data_ptr = out.data<T>() + loc;
|
||||||
|
|
||||||
|
StridedIterator st(data_ptr, axis_stride, 0);
|
||||||
|
StridedIterator md(data_ptr, axis_stride, kth);
|
||||||
|
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
||||||
|
|
||||||
|
std::nth_element(st, md, ed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename IdxT = uint32_t>
|
||||||
|
void argpartition(const array& in, array& out, int axis, int kth) {
|
||||||
|
// Allocate output
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
// Get axis, shape and stride info
|
||||||
|
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||||
|
size_t n_rows = in.size() / in.shape(axis);
|
||||||
|
|
||||||
|
auto remaining_shape = in.shape();
|
||||||
|
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||||
|
|
||||||
|
auto remaining_strides = in.strides();
|
||||||
|
remaining_strides.erase(remaining_strides.begin() + axis);
|
||||||
|
|
||||||
|
size_t axis_stride = in.strides()[axis];
|
||||||
|
int axis_size = in.shape(axis);
|
||||||
|
|
||||||
|
kth = kth < 0 ? kth + axis_size : kth;
|
||||||
|
|
||||||
|
// Perform partition
|
||||||
|
for (int i = 0; i < n_rows; i++) {
|
||||||
|
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
|
||||||
|
const T* data_ptr = in.data<T>() + loc;
|
||||||
|
IdxT* idx_ptr = out.data<IdxT>() + loc;
|
||||||
|
|
||||||
|
StridedIterator st_(idx_ptr, axis_stride, 0);
|
||||||
|
StridedIterator ed_(idx_ptr, axis_stride, axis_size);
|
||||||
|
|
||||||
|
// Initialize with iota
|
||||||
|
std::iota(st_, ed_, IdxT(0));
|
||||||
|
|
||||||
|
// Sort according to vals
|
||||||
|
StridedIterator st(idx_ptr, axis_stride, 0);
|
||||||
|
StridedIterator md(idx_ptr, axis_stride, kth);
|
||||||
|
StridedIterator ed(idx_ptr, axis_stride, axis_size);
|
||||||
|
|
||||||
|
std::nth_element(st, md, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {
|
||||||
|
auto v1 = data_ptr[a * axis_stride];
|
||||||
|
auto v2 = data_ptr[b * axis_stride];
|
||||||
|
return v1 < v2 || (v1 == v2 && a < b);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void ArgSort::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
|
||||||
|
switch (in.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
return argsort<bool>(in, out, axis_);
|
||||||
|
case uint8:
|
||||||
|
return argsort<uint8_t>(in, out, axis_);
|
||||||
|
case uint16:
|
||||||
|
return argsort<uint16_t>(in, out, axis_);
|
||||||
|
case uint32:
|
||||||
|
return argsort<uint32_t>(in, out, axis_);
|
||||||
|
case uint64:
|
||||||
|
return argsort<uint64_t>(in, out, axis_);
|
||||||
|
case int8:
|
||||||
|
return argsort<int8_t>(in, out, axis_);
|
||||||
|
case int16:
|
||||||
|
return argsort<int16_t>(in, out, axis_);
|
||||||
|
case int32:
|
||||||
|
return argsort<int32_t>(in, out, axis_);
|
||||||
|
case int64:
|
||||||
|
return argsort<int64_t>(in, out, axis_);
|
||||||
|
case float32:
|
||||||
|
return argsort<float>(in, out, axis_);
|
||||||
|
case float16:
|
||||||
|
return argsort<float16_t>(in, out, axis_);
|
||||||
|
case bfloat16:
|
||||||
|
return argsort<bfloat16_t>(in, out, axis_);
|
||||||
|
case complex64:
|
||||||
|
return argsort<complex64_t>(in, out, axis_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Sort::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
|
||||||
|
switch (in.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
return sort<bool>(in, out, axis_);
|
||||||
|
case uint8:
|
||||||
|
return sort<uint8_t>(in, out, axis_);
|
||||||
|
case uint16:
|
||||||
|
return sort<uint16_t>(in, out, axis_);
|
||||||
|
case uint32:
|
||||||
|
return sort<uint32_t>(in, out, axis_);
|
||||||
|
case uint64:
|
||||||
|
return sort<uint64_t>(in, out, axis_);
|
||||||
|
case int8:
|
||||||
|
return sort<int8_t>(in, out, axis_);
|
||||||
|
case int16:
|
||||||
|
return sort<int16_t>(in, out, axis_);
|
||||||
|
case int32:
|
||||||
|
return sort<int32_t>(in, out, axis_);
|
||||||
|
case int64:
|
||||||
|
return sort<int64_t>(in, out, axis_);
|
||||||
|
case float32:
|
||||||
|
return sort<float>(in, out, axis_);
|
||||||
|
case float16:
|
||||||
|
return sort<float16_t>(in, out, axis_);
|
||||||
|
case bfloat16:
|
||||||
|
return sort<bfloat16_t>(in, out, axis_);
|
||||||
|
case complex64:
|
||||||
|
return sort<complex64_t>(in, out, axis_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ArgPartition::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
|
||||||
|
switch (in.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
return argpartition<bool>(in, out, axis_, kth_);
|
||||||
|
case uint8:
|
||||||
|
return argpartition<uint8_t>(in, out, axis_, kth_);
|
||||||
|
case uint16:
|
||||||
|
return argpartition<uint16_t>(in, out, axis_, kth_);
|
||||||
|
case uint32:
|
||||||
|
return argpartition<uint32_t>(in, out, axis_, kth_);
|
||||||
|
case uint64:
|
||||||
|
return argpartition<uint64_t>(in, out, axis_, kth_);
|
||||||
|
case int8:
|
||||||
|
return argpartition<int8_t>(in, out, axis_, kth_);
|
||||||
|
case int16:
|
||||||
|
return argpartition<int16_t>(in, out, axis_, kth_);
|
||||||
|
case int32:
|
||||||
|
return argpartition<int32_t>(in, out, axis_, kth_);
|
||||||
|
case int64:
|
||||||
|
return argpartition<int64_t>(in, out, axis_, kth_);
|
||||||
|
case float32:
|
||||||
|
return argpartition<float>(in, out, axis_, kth_);
|
||||||
|
case float16:
|
||||||
|
return argpartition<float16_t>(in, out, axis_, kth_);
|
||||||
|
case bfloat16:
|
||||||
|
return argpartition<bfloat16_t>(in, out, axis_, kth_);
|
||||||
|
case complex64:
|
||||||
|
return argpartition<complex64_t>(in, out, axis_, kth_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Partition::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
|
||||||
|
switch (in.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
return partition<bool>(in, out, axis_, kth_);
|
||||||
|
case uint8:
|
||||||
|
return partition<uint8_t>(in, out, axis_, kth_);
|
||||||
|
case uint16:
|
||||||
|
return partition<uint16_t>(in, out, axis_, kth_);
|
||||||
|
case uint32:
|
||||||
|
return partition<uint32_t>(in, out, axis_, kth_);
|
||||||
|
case uint64:
|
||||||
|
return partition<uint64_t>(in, out, axis_, kth_);
|
||||||
|
case int8:
|
||||||
|
return partition<int8_t>(in, out, axis_, kth_);
|
||||||
|
case int16:
|
||||||
|
return partition<int16_t>(in, out, axis_, kth_);
|
||||||
|
case int32:
|
||||||
|
return partition<int32_t>(in, out, axis_, kth_);
|
||||||
|
case int64:
|
||||||
|
return partition<int64_t>(in, out, axis_, kth_);
|
||||||
|
case float32:
|
||||||
|
return partition<float>(in, out, axis_, kth_);
|
||||||
|
case float16:
|
||||||
|
return partition<float16_t>(in, out, axis_, kth_);
|
||||||
|
case bfloat16:
|
||||||
|
return partition<bfloat16_t>(in, out, axis_, kth_);
|
||||||
|
case complex64:
|
||||||
|
return partition<complex64_t>(in, out, axis_, kth_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
257
mlx/backend/metal/device.cpp
Normal file
257
mlx/backend/metal/device.cpp
Normal file
@ -0,0 +1,257 @@
|
|||||||
|
#include <dlfcn.h>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <filesystem>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#define NS_PRIVATE_IMPLEMENTATION
|
||||||
|
#define CA_PRIVATE_IMPLEMENTATION
|
||||||
|
#define MTL_PRIVATE_IMPLEMENTATION
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/metal.h"
|
||||||
|
#include "mlx/backend/metal/mps/gemm.h"
|
||||||
|
|
||||||
|
namespace fs = std::filesystem;
|
||||||
|
|
||||||
|
namespace mlx::core::metal {
|
||||||
|
|
||||||
|
static Device metal_device_;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// TODO nicer way to set this or possibly expose as an environment variable
|
||||||
|
static constexpr int MAX_BUFFERS_PER_QUEUE = 12;
|
||||||
|
|
||||||
|
static constexpr const char* default_mtllib_path = METAL_PATH;
|
||||||
|
|
||||||
|
auto load_device() {
|
||||||
|
MTL::Device* device = MTL::CreateSystemDefaultDevice();
|
||||||
|
if (!device) {
|
||||||
|
throw std::runtime_error("Failed to load device");
|
||||||
|
}
|
||||||
|
return device;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<MTL::Library*, NS::Error*> load_library_from_path(
|
||||||
|
MTL::Device* device,
|
||||||
|
const char* path) {
|
||||||
|
auto library = NS::String::string(path, NS::UTF8StringEncoding);
|
||||||
|
NS::Error* error;
|
||||||
|
auto lib = device->newLibrary(library, &error);
|
||||||
|
|
||||||
|
return std::make_pair(lib, error);
|
||||||
|
}
|
||||||
|
|
||||||
|
MTL::Library* load_library(
|
||||||
|
MTL::Device* device,
|
||||||
|
const std::string& lib_name = "mlx",
|
||||||
|
const char* lib_path = default_mtllib_path) {
|
||||||
|
// Firstly, search for the metallib in the same path as this binary
|
||||||
|
std::string first_path = get_colocated_mtllib_path(lib_name);
|
||||||
|
if (first_path.size() != 0) {
|
||||||
|
auto [lib, error] = load_library_from_path(device, first_path.c_str());
|
||||||
|
if (lib) {
|
||||||
|
return lib;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Couldn't find it so let's load it from default_mtllib_path
|
||||||
|
{
|
||||||
|
auto [lib, error] = load_library_from_path(device, lib_path);
|
||||||
|
if (!lib) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << error->localizedDescription()->utf8String() << "\n"
|
||||||
|
<< "Failed to load device library from <" << lib_path << ">"
|
||||||
|
<< " or <" << first_path << ">.";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
return lib;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
Device::Device()
|
||||||
|
: pool_(NS::AutoreleasePool::alloc()->init()),
|
||||||
|
device_(load_device()),
|
||||||
|
library_map_({{"mlx", load_library(device_)}}) {}
|
||||||
|
|
||||||
|
Device::~Device() {
|
||||||
|
for (auto& q : queue_map_) {
|
||||||
|
q.second->release();
|
||||||
|
}
|
||||||
|
for (auto& k : kernel_map_) {
|
||||||
|
k.second->release();
|
||||||
|
}
|
||||||
|
for (auto& l : library_map_) {
|
||||||
|
l.second->release();
|
||||||
|
}
|
||||||
|
device_->release();
|
||||||
|
pool_->release();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Device::new_queue(int index) {
|
||||||
|
// Multiple threads can ask the device for queues
|
||||||
|
// We lock this as a critical section for safety
|
||||||
|
const std::lock_guard<std::mutex> lock(mtx_);
|
||||||
|
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
|
||||||
|
if (!q) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[metal::Device] Failed to make new command queue.");
|
||||||
|
}
|
||||||
|
queue_map_.insert({index, q});
|
||||||
|
}
|
||||||
|
|
||||||
|
int Device::get_command_buffer_ops(int index) {
|
||||||
|
auto bit = buffer_map_.find(index);
|
||||||
|
return bit->second.first;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Device::increment_command_buffer_ops(int index) {
|
||||||
|
auto bit = buffer_map_.find(index);
|
||||||
|
bit->second.first++;
|
||||||
|
}
|
||||||
|
|
||||||
|
MTL::CommandBuffer* Device::get_command_buffer(int index) {
|
||||||
|
auto bit = buffer_map_.find(index);
|
||||||
|
return (bit == buffer_map_.end()) ? nullptr : bit->second.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
MTL::CommandBuffer* Device::new_command_buffer(int index) {
|
||||||
|
auto qit = queue_map_.find(index);
|
||||||
|
if (qit == queue_map_.end()) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[metal::Device] Attempting to get command buffer for invalid queue.");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto cb = qit->second->commandBufferWithUnretainedReferences();
|
||||||
|
|
||||||
|
if (!cb) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[metal::Device] Unable to create new command buffer");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment ref count so the buffer is not garbage collected
|
||||||
|
cb->retain();
|
||||||
|
|
||||||
|
return buffer_map_.insert({index, {0, cb}}).first->second.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Device::commit_command_buffer(int index) {
|
||||||
|
auto bit = buffer_map_.find(index);
|
||||||
|
bit->second.second->commit();
|
||||||
|
bit->second.second->release();
|
||||||
|
buffer_map_.erase(bit);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Device::end_encoding(int index) {
|
||||||
|
auto eit = encoder_map_.find(index);
|
||||||
|
if (eit != encoder_map_.end()) {
|
||||||
|
eit->second->endEncoding();
|
||||||
|
eit->second->release();
|
||||||
|
encoder_map_.erase(eit);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
MTL::ComputeCommandEncoder* Device::get_command_encoder(int index) {
|
||||||
|
auto eit = encoder_map_.find(index);
|
||||||
|
if (eit == encoder_map_.end()) {
|
||||||
|
auto cb = get_command_buffer(index);
|
||||||
|
auto compute_encoder = cb->computeCommandEncoder();
|
||||||
|
// Increment ref count so the buffer is not garbage collected
|
||||||
|
compute_encoder->retain();
|
||||||
|
eit = encoder_map_.insert({index, compute_encoder}).first;
|
||||||
|
}
|
||||||
|
return eit->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
MTL::ArgumentEncoder* Device::argument_encoder(
|
||||||
|
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const {
|
||||||
|
// NB array here is already autoreleased but the returned argument
|
||||||
|
// encoder is owned by the caller and must be released/autoreleased
|
||||||
|
NS::Array* arg_desc_arr = NS::Array::array(
|
||||||
|
reinterpret_cast<NS::Object* const*>(arg_descs.data()), arg_descs.size());
|
||||||
|
return device_->newArgumentEncoder(arg_desc_arr);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Device::register_library(
|
||||||
|
const std::string& lib_name,
|
||||||
|
const std::string& lib_path) {
|
||||||
|
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
|
||||||
|
auto new_lib = load_library(device_, lib_name, lib_path.c_str());
|
||||||
|
library_map_.insert({lib_name, new_lib});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Device::register_library(
|
||||||
|
const std::string& lib_name,
|
||||||
|
const std::function<std::string(const std::string&)>& lib_path_func) {
|
||||||
|
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
|
||||||
|
std::string new_lib_path = lib_path_func(lib_name);
|
||||||
|
auto new_lib = load_library(device_, lib_name, new_lib_path.c_str());
|
||||||
|
library_map_.insert({lib_name, new_lib});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* Device::get_kernel(
|
||||||
|
const std::string& name,
|
||||||
|
const std::string& lib_name /* = "mlx" */) {
|
||||||
|
// Look for cached kernel
|
||||||
|
if (auto it = kernel_map_.find(name); it != kernel_map_.end()) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare new kernel
|
||||||
|
|
||||||
|
// Search for cached metal lib
|
||||||
|
MTL::Library* mtl_lib;
|
||||||
|
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||||
|
mtl_lib = it->second;
|
||||||
|
} else { // Look for metallib alongside library
|
||||||
|
register_library(lib_name);
|
||||||
|
mtl_lib = library_map_[lib_name];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pull kernel from library
|
||||||
|
auto ns_name = NS::String::string(name.c_str(), NS::ASCIIStringEncoding);
|
||||||
|
auto mtl_function = mtl_lib->newFunction(ns_name);
|
||||||
|
|
||||||
|
// Compile kernel to compute pipeline
|
||||||
|
NS::Error* error = nullptr;
|
||||||
|
MTL::ComputePipelineState* kernel;
|
||||||
|
if (mtl_function) {
|
||||||
|
kernel = device_->newComputePipelineState(mtl_function, &error);
|
||||||
|
mtl_function->release();
|
||||||
|
}
|
||||||
|
if (!mtl_function || !kernel) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[metal::Device] Unable to load kernel " << name << "\n";
|
||||||
|
if (error) {
|
||||||
|
msg << error->localizedDescription()->utf8String() << "\n";
|
||||||
|
}
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add kernel to cache
|
||||||
|
kernel_map_.insert({name, kernel});
|
||||||
|
return kernel;
|
||||||
|
}
|
||||||
|
|
||||||
|
Device& device(mlx::core::Device) {
|
||||||
|
return metal_device_;
|
||||||
|
}
|
||||||
|
|
||||||
|
NS::AutoreleasePool*& thread_autorelease_pool() {
|
||||||
|
static thread_local NS::AutoreleasePool* p =
|
||||||
|
NS::AutoreleasePool::alloc()->init();
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
|
||||||
|
void new_stream(Stream stream) {
|
||||||
|
thread_autorelease_pool();
|
||||||
|
if (stream.device == mlx::core::Device::gpu) {
|
||||||
|
device(stream.device).new_queue(stream.index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::metal
|
10
mlx/backend/metal/fft.cpp
Normal file
10
mlx/backend/metal/fft.cpp
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
auto& in = inputs[0];
|
||||||
|
throw std::runtime_error("[FFT] NYI for Metal backend.");
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
83
mlx/backend/metal/kernels/CMakeLists.txt
Normal file
83
mlx/backend/metal/kernels/CMakeLists.txt
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
set(
|
||||||
|
HEADERS
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/bf16.h
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/complex.h
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/defines.h
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/erf.h
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce.h
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/utils.h
|
||||||
|
)
|
||||||
|
|
||||||
|
set(
|
||||||
|
KERNELS
|
||||||
|
"arange"
|
||||||
|
"arg_reduce"
|
||||||
|
"binary"
|
||||||
|
"conv"
|
||||||
|
"copy"
|
||||||
|
"gemm"
|
||||||
|
"gemv"
|
||||||
|
"random"
|
||||||
|
"reduce"
|
||||||
|
"scan"
|
||||||
|
"softmax"
|
||||||
|
"sort"
|
||||||
|
"unary"
|
||||||
|
"indexing"
|
||||||
|
)
|
||||||
|
|
||||||
|
function(build_kernel KERNEL)
|
||||||
|
set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
|
||||||
|
set(HEADERS_PADDED ${HEADERS})
|
||||||
|
if(${KERNEL} STREQUAL "gemm")
|
||||||
|
set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/gemm/gemm.h)
|
||||||
|
endif()
|
||||||
|
if(${KERNEL} STREQUAL "conv")
|
||||||
|
set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/gemm/conv.h)
|
||||||
|
endif()
|
||||||
|
add_custom_command(
|
||||||
|
COMMAND xcrun -sdk macosx metal -Wall -Wextra
|
||||||
|
-fno-fast-math
|
||||||
|
-c ${SRCFILE}
|
||||||
|
-I${PROJECT_SOURCE_DIR}
|
||||||
|
-o ${KERNEL}.air
|
||||||
|
DEPENDS ${SRCFILE} ${HEADERS_PADDED}
|
||||||
|
OUTPUT ${KERNEL}.air
|
||||||
|
COMMENT "Building ${KERNEL}.air"
|
||||||
|
VERBATIM
|
||||||
|
)
|
||||||
|
endfunction(build_kernel)
|
||||||
|
|
||||||
|
foreach(KERNEL ${KERNELS})
|
||||||
|
build_kernel(${KERNEL})
|
||||||
|
set(KERNEL_AIR ${KERNEL}.air ${KERNEL_AIR})
|
||||||
|
endforeach()
|
||||||
|
|
||||||
|
add_custom_command(
|
||||||
|
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
|
||||||
|
COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib
|
||||||
|
DEPENDS ${KERNEL_AIR}
|
||||||
|
COMMENT "Building mlx.metallib"
|
||||||
|
VERBATIM
|
||||||
|
)
|
||||||
|
|
||||||
|
add_custom_target(
|
||||||
|
mlx-metallib
|
||||||
|
DEPENDS
|
||||||
|
${MLX_METAL_PATH}/mlx.metallib
|
||||||
|
)
|
||||||
|
|
||||||
|
add_dependencies(
|
||||||
|
mlx
|
||||||
|
mlx-metallib
|
||||||
|
)
|
||||||
|
|
||||||
|
# Install metallib
|
||||||
|
include(GNUInstallDirs)
|
||||||
|
|
||||||
|
install(
|
||||||
|
FILES ${MLX_METAL_PATH}/mlx.metallib
|
||||||
|
DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||||
|
COMPONENT metallib
|
||||||
|
)
|
17
mlx/backend/metal/kernels/conv_params.h
Normal file
17
mlx/backend/metal/kernels/conv_params.h
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
template <int NDIM>
|
||||||
|
struct MLXConvParams {
|
||||||
|
const int N; // Batch size
|
||||||
|
const int C; // In channels
|
||||||
|
const int O; // Out channels
|
||||||
|
const int iS[NDIM]; // Input spatial dim
|
||||||
|
const int wS[NDIM]; // Weight spatial dim
|
||||||
|
const int oS[NDIM]; // Output spatial dim
|
||||||
|
const int str[NDIM]; // Kernel strides
|
||||||
|
const int pad[NDIM]; // Input padding
|
||||||
|
const int dil[NDIM]; // Kernel dilation
|
||||||
|
const size_t in_strides[NDIM + 2]; // In strides
|
||||||
|
const size_t wt_strides[NDIM + 2]; // Wt strides
|
||||||
|
const size_t out_strides[NDIM + 2]; // Out strides
|
||||||
|
};
|
253
mlx/backend/metal/kernels/indexing.metal
Normal file
253
mlx/backend/metal/kernels/indexing.metal
Normal file
@ -0,0 +1,253 @@
|
|||||||
|
#include <metal_atomic>
|
||||||
|
#include <metal_texture>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
|
#include "mlx/backend/metal/kernels/reduce.h"
|
||||||
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////
|
||||||
|
// Gather kernel
|
||||||
|
/////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <typename IdxT, int NIDX>
|
||||||
|
struct Indices {
|
||||||
|
const array<device IdxT*, NIDX> buffers [[id(0)]];
|
||||||
|
device int* shapes [[id(NIDX + 1)]];
|
||||||
|
device size_t* strides [[id(NIDX + 2)]];
|
||||||
|
const int ndim [[id(NIDX + 3)]];
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename IdxT>
|
||||||
|
inline size_t offset_neg_idx(IdxT idx, size_t size) {
|
||||||
|
return (idx < 0) ? idx + size : idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline size_t offset_neg_idx(bool idx, size_t) {
|
||||||
|
return idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline size_t offset_neg_idx(uint32_t idx, size_t) {
|
||||||
|
return idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename IdxT, int NIDX>
|
||||||
|
[[kernel]] void gather(
|
||||||
|
const device T *src [[buffer(0)]],
|
||||||
|
const device Indices<IdxT, NIDX>& indices [[buffer(1)]],
|
||||||
|
device T *out [[buffer(2)]],
|
||||||
|
const device int *src_shape [[buffer(3)]],
|
||||||
|
const device size_t *src_strides [[buffer(4)]],
|
||||||
|
const device size_t& src_ndim [[buffer(5)]],
|
||||||
|
const device int *slice_sizes [[buffer(6)]],
|
||||||
|
const device size_t& slice_size [[buffer(7)]],
|
||||||
|
const device int *axes [[buffer(8)]],
|
||||||
|
uint gid [[thread_position_in_grid]]) {
|
||||||
|
|
||||||
|
auto ind_idx = gid / slice_size;
|
||||||
|
auto ind_offset = gid % slice_size;
|
||||||
|
|
||||||
|
size_t src_idx = 0;
|
||||||
|
for (int i = 0; i < NIDX; ++i) {
|
||||||
|
auto idx_loc = elem_to_loc(
|
||||||
|
ind_idx,
|
||||||
|
&indices.shapes[indices.ndim * i],
|
||||||
|
&indices.strides[indices.ndim * i],
|
||||||
|
indices.ndim);
|
||||||
|
auto ax = axes[i];
|
||||||
|
auto idx_val = offset_neg_idx(
|
||||||
|
indices.buffers[i][idx_loc], src_shape[ax]);
|
||||||
|
src_idx += idx_val * src_strides[ax];
|
||||||
|
}
|
||||||
|
|
||||||
|
auto src_offset = elem_to_loc(
|
||||||
|
ind_offset, slice_sizes, src_strides, src_ndim);
|
||||||
|
out[gid] = src[src_idx + src_offset];
|
||||||
|
}
|
||||||
|
|
||||||
|
#define instantiate_gather4(name, src_type, ind_type, nindex) \
|
||||||
|
template [[host_name("gather" name "_" #nindex)]] \
|
||||||
|
[[kernel]] void gather( \
|
||||||
|
const device src_type *src [[buffer(0)]], \
|
||||||
|
const device Indices<ind_type, nindex>& indices [[buffer(1)]], \
|
||||||
|
device src_type *out [[buffer(2)]], \
|
||||||
|
const device int *src_shape [[buffer(3)]], \
|
||||||
|
const device size_t *src_strides [[buffer(4)]], \
|
||||||
|
const device size_t& src_ndim [[buffer(5)]], \
|
||||||
|
const device int *slice_sizes [[buffer(6)]], \
|
||||||
|
const device size_t& slice_size [[buffer(7)]], \
|
||||||
|
const device int* axes [[buffer(8)]], \
|
||||||
|
uint gid [[thread_position_in_grid]]);
|
||||||
|
|
||||||
|
// Special for case NIDX=0
|
||||||
|
instantiate_gather4("bool_", bool, bool, 0)
|
||||||
|
instantiate_gather4("uint8", uint8_t, bool, 0)
|
||||||
|
instantiate_gather4("uint16", uint16_t, bool, 0)
|
||||||
|
instantiate_gather4("uint32", uint32_t, bool, 0)
|
||||||
|
instantiate_gather4("uint64", uint64_t, bool, 0)
|
||||||
|
instantiate_gather4("int8", int8_t, bool, 0)
|
||||||
|
instantiate_gather4("int16", int16_t, bool, 0)
|
||||||
|
instantiate_gather4("int32", int32_t, bool, 0)
|
||||||
|
instantiate_gather4("int64", int64_t, bool, 0)
|
||||||
|
instantiate_gather4("float16", half, bool, 0)
|
||||||
|
instantiate_gather4("float32", float, bool, 0)
|
||||||
|
instantiate_gather4("bfloat16", bfloat16_t, bool, 0)
|
||||||
|
|
||||||
|
#define instantiate_gather3(name, src_type, ind_type) \
|
||||||
|
instantiate_gather4(name, src_type, ind_type, 1) \
|
||||||
|
instantiate_gather4(name, src_type, ind_type, 2) \
|
||||||
|
instantiate_gather4(name, src_type, ind_type, 3) \
|
||||||
|
instantiate_gather4(name, src_type, ind_type, 4) \
|
||||||
|
instantiate_gather4(name, src_type, ind_type, 5) \
|
||||||
|
instantiate_gather4(name, src_type, ind_type, 6) \
|
||||||
|
instantiate_gather4(name, src_type, ind_type, 7) \
|
||||||
|
instantiate_gather4(name, src_type, ind_type, 8) \
|
||||||
|
instantiate_gather4(name, src_type, ind_type, 9) \
|
||||||
|
instantiate_gather4(name, src_type, ind_type, 10)
|
||||||
|
|
||||||
|
#define instantiate_gather(name, src_type) \
|
||||||
|
instantiate_gather3(#name "bool_", src_type, bool) \
|
||||||
|
instantiate_gather3(#name "uint8", src_type, uint8_t) \
|
||||||
|
instantiate_gather3(#name "uint16", src_type, uint16_t) \
|
||||||
|
instantiate_gather3(#name "uint32", src_type, uint32_t) \
|
||||||
|
instantiate_gather3(#name "uint64", src_type, uint64_t) \
|
||||||
|
instantiate_gather3(#name "int8", src_type, int8_t) \
|
||||||
|
instantiate_gather3(#name "int16", src_type, int16_t) \
|
||||||
|
instantiate_gather3(#name "int32", src_type, int32_t) \
|
||||||
|
instantiate_gather3(#name "int64", src_type, int64_t)
|
||||||
|
|
||||||
|
instantiate_gather(bool_, bool)
|
||||||
|
instantiate_gather(uint8, uint8_t)
|
||||||
|
instantiate_gather(uint16, uint16_t)
|
||||||
|
instantiate_gather(uint32, uint32_t)
|
||||||
|
instantiate_gather(uint64, uint64_t)
|
||||||
|
instantiate_gather(int8, int8_t)
|
||||||
|
instantiate_gather(int16, int16_t)
|
||||||
|
instantiate_gather(int32, int32_t)
|
||||||
|
instantiate_gather(int64, int64_t)
|
||||||
|
instantiate_gather(float16, half)
|
||||||
|
instantiate_gather(float32, float)
|
||||||
|
instantiate_gather(bfloat16, bfloat16_t)
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////
|
||||||
|
// Scatter kernel
|
||||||
|
/////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <typename T, typename IdxT, typename Op, int NIDX>
|
||||||
|
[[kernel]] void scatter(
|
||||||
|
const device Indices<IdxT, NIDX>& indices [[buffer(0)]],
|
||||||
|
const device T *updates [[buffer(1)]],
|
||||||
|
device mlx_atomic<T> *out [[buffer(2)]],
|
||||||
|
const device int *upd_shape [[buffer(3)]],
|
||||||
|
const device size_t *upd_strides [[buffer(4)]],
|
||||||
|
const device size_t& upd_ndim [[buffer(5)]],
|
||||||
|
const device size_t& upd_size [[buffer(6)]],
|
||||||
|
const device int *out_shape [[buffer(7)]],
|
||||||
|
const device size_t *out_strides [[buffer(8)]],
|
||||||
|
const device size_t& out_ndim [[buffer(9)]],
|
||||||
|
const device int* axes [[buffer(10)]],
|
||||||
|
uint gid [[thread_position_in_grid]]) {
|
||||||
|
|
||||||
|
Op op;
|
||||||
|
auto ind_idx = gid / upd_size;
|
||||||
|
auto ind_offset = gid % upd_size;
|
||||||
|
|
||||||
|
size_t out_idx = 0;
|
||||||
|
for (int i = 0; i < NIDX; ++i) {
|
||||||
|
auto idx_loc = elem_to_loc(
|
||||||
|
ind_idx,
|
||||||
|
&indices.shapes[indices.ndim * i],
|
||||||
|
&indices.strides[indices.ndim * i],
|
||||||
|
indices.ndim);
|
||||||
|
auto ax = axes[i];
|
||||||
|
auto idx_val = offset_neg_idx(
|
||||||
|
indices.buffers[i][idx_loc], out_shape[ax]);
|
||||||
|
out_idx += idx_val * out_strides[ax];
|
||||||
|
}
|
||||||
|
|
||||||
|
auto out_offset = elem_to_loc(
|
||||||
|
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
|
||||||
|
auto upd_idx = elem_to_loc(gid, upd_shape, upd_strides, upd_ndim);
|
||||||
|
|
||||||
|
op.atomic_update(out + out_idx + out_offset, updates[upd_idx]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#define instantiate_scatter4(name, type, ind_type, op_type, nindex) \
|
||||||
|
template [[host_name("scatter" name "_" #nindex)]] \
|
||||||
|
[[kernel]] void scatter<type, ind_type, op_type, nindex>( \
|
||||||
|
const device Indices<ind_type, nindex>& indices [[buffer(0)]], \
|
||||||
|
const device type *updates [[buffer(1)]], \
|
||||||
|
device mlx_atomic<type> *out [[buffer(2)]], \
|
||||||
|
const device int *upd_shape [[buffer(3)]], \
|
||||||
|
const device size_t *upd_strides [[buffer(4)]], \
|
||||||
|
const device size_t& upd_ndim [[buffer(5)]], \
|
||||||
|
const device size_t& upd_size [[buffer(6)]], \
|
||||||
|
const device int *out_shape [[buffer(7)]], \
|
||||||
|
const device size_t *out_strides [[buffer(8)]], \
|
||||||
|
const device size_t& out_ndim [[buffer(9)]], \
|
||||||
|
const device int* axes [[buffer(10)]], \
|
||||||
|
uint gid [[thread_position_in_grid]]);
|
||||||
|
|
||||||
|
// Special case NINDEX=0
|
||||||
|
#define instantiate_scatter_nd0(name, type) \
|
||||||
|
instantiate_scatter4(#name "none", type, bool, None, 0) \
|
||||||
|
instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
|
||||||
|
instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
|
||||||
|
instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
|
||||||
|
instantiate_scatter4(#name "_min", type, bool, Min<type>, 0)
|
||||||
|
|
||||||
|
#define instantiate_scatter3(name, type, ind_type, op_type) \
|
||||||
|
instantiate_scatter4(name, type, ind_type, op_type, 1) \
|
||||||
|
instantiate_scatter4(name, type, ind_type, op_type, 2) \
|
||||||
|
instantiate_scatter4(name, type, ind_type, op_type, 3) \
|
||||||
|
instantiate_scatter4(name, type, ind_type, op_type, 4) \
|
||||||
|
instantiate_scatter4(name, type, ind_type, op_type, 5) \
|
||||||
|
instantiate_scatter4(name, type, ind_type, op_type, 6) \
|
||||||
|
instantiate_scatter4(name, type, ind_type, op_type, 7) \
|
||||||
|
instantiate_scatter4(name, type, ind_type, op_type, 8) \
|
||||||
|
instantiate_scatter4(name, type, ind_type, op_type, 9) \
|
||||||
|
instantiate_scatter4(name, type, ind_type, op_type, 10)
|
||||||
|
|
||||||
|
#define instantiate_scatter2(name, type, ind_type) \
|
||||||
|
instantiate_scatter3(name "_none", type, ind_type, None) \
|
||||||
|
instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
|
||||||
|
instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
|
||||||
|
instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
|
||||||
|
instantiate_scatter3(name "_min", type, ind_type, Min<type>)
|
||||||
|
|
||||||
|
#define instantiate_scatter(name, type) \
|
||||||
|
instantiate_scatter2(#name "bool_", type, bool) \
|
||||||
|
instantiate_scatter2(#name "uint8", type, uint8_t) \
|
||||||
|
instantiate_scatter2(#name "uint16", type, uint16_t) \
|
||||||
|
instantiate_scatter2(#name "uint32", type, uint32_t) \
|
||||||
|
instantiate_scatter2(#name "uint64", type, uint64_t) \
|
||||||
|
instantiate_scatter2(#name "int8", type, int8_t) \
|
||||||
|
instantiate_scatter2(#name "int16", type, int16_t) \
|
||||||
|
instantiate_scatter2(#name "int32", type, int32_t) \
|
||||||
|
instantiate_scatter2(#name "int64", type, int64_t)
|
||||||
|
|
||||||
|
// TODO uint64 and int64 unsupported
|
||||||
|
instantiate_scatter_nd0(bool_, bool)
|
||||||
|
instantiate_scatter_nd0(uint8, uint8_t)
|
||||||
|
instantiate_scatter_nd0(uint16, uint16_t)
|
||||||
|
instantiate_scatter_nd0(uint32, uint32_t)
|
||||||
|
instantiate_scatter_nd0(int8, int8_t)
|
||||||
|
instantiate_scatter_nd0(int16, int16_t)
|
||||||
|
instantiate_scatter_nd0(int32, int32_t)
|
||||||
|
instantiate_scatter_nd0(float16, half)
|
||||||
|
instantiate_scatter_nd0(float32, float)
|
||||||
|
instantiate_scatter_nd0(bfloat16, bfloat16_t)
|
||||||
|
|
||||||
|
instantiate_scatter(bool_, bool)
|
||||||
|
instantiate_scatter(uint8, uint8_t)
|
||||||
|
instantiate_scatter(uint16, uint16_t)
|
||||||
|
instantiate_scatter(uint32, uint32_t)
|
||||||
|
instantiate_scatter(int8, int8_t)
|
||||||
|
instantiate_scatter(int16, int16_t)
|
||||||
|
instantiate_scatter(int32, int32_t)
|
||||||
|
instantiate_scatter(float16, half)
|
||||||
|
instantiate_scatter(float32, float)
|
||||||
|
instantiate_scatter(bfloat16, bfloat16_t)
|
174
mlx/backend/metal/kernels/reduce.h
Normal file
174
mlx/backend/metal/kernels/reduce.h
Normal file
@ -0,0 +1,174 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <metal_atomic>
|
||||||
|
#include <metal_simdgroup>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/atomic.h"
|
||||||
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
|
union bool4_or_uint {
|
||||||
|
bool4 b;
|
||||||
|
unsigned int i;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct None {
|
||||||
|
template <typename T>
|
||||||
|
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
|
||||||
|
mlx_atomic_store_explicit(out, val, offset);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct And {
|
||||||
|
bool simd_reduce(bool val) {
|
||||||
|
return simd_all(val);
|
||||||
|
};
|
||||||
|
|
||||||
|
static constexpr constant bool init = true;
|
||||||
|
|
||||||
|
void atomic_update(
|
||||||
|
device mlx_atomic<unsigned int>* out,
|
||||||
|
bool val,
|
||||||
|
int elem_idx,
|
||||||
|
int offset = 0) {
|
||||||
|
if (!val) {
|
||||||
|
bool4_or_uint update;
|
||||||
|
update.b = {true, true, true, true};
|
||||||
|
update.b[elem_idx] = false;
|
||||||
|
mlx_atomic_fetch_and_explicit(out, update.i, offset);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void atomic_update(device mlx_atomic<bool>* out, bool val, int offset = 0) {
|
||||||
|
if (!val) {
|
||||||
|
mlx_atomic_store_explicit(out, val, offset);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non atomic update
|
||||||
|
void update(device bool* out, bool val) {
|
||||||
|
*out &= val;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Operator
|
||||||
|
bool operator()(bool a, bool b) {
|
||||||
|
return a && b;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Or {
|
||||||
|
bool simd_reduce(bool val) {
|
||||||
|
return simd_any(val);
|
||||||
|
};
|
||||||
|
|
||||||
|
static constexpr constant bool init = false;
|
||||||
|
|
||||||
|
void atomic_update(
|
||||||
|
device mlx_atomic<unsigned int>* out,
|
||||||
|
bool val,
|
||||||
|
int elem_idx,
|
||||||
|
int offset = 0) {
|
||||||
|
if (val) {
|
||||||
|
bool4_or_uint update;
|
||||||
|
update.b = {false, false, false, false};
|
||||||
|
update.b[elem_idx] = true;
|
||||||
|
mlx_atomic_fetch_or_explicit(out, update.i, offset);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void atomic_update(device mlx_atomic<bool>* out, bool val, int offset = 0) {
|
||||||
|
if (val) {
|
||||||
|
mlx_atomic_store_explicit(out, val, offset);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non atomic update
|
||||||
|
void update(device bool* out, bool val) {
|
||||||
|
*out |= val;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Operator
|
||||||
|
bool operator()(bool a, bool b) {
|
||||||
|
return a || b;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
struct Sum {
|
||||||
|
template <typename T>
|
||||||
|
T simd_reduce(T val) {
|
||||||
|
return simd_sum(val);
|
||||||
|
};
|
||||||
|
|
||||||
|
static constexpr constant U init = U(0);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
|
||||||
|
mlx_atomic_fetch_add_explicit(out, val, offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Operator
|
||||||
|
U operator()(U a, U b) {
|
||||||
|
return a + b;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
struct Prod {
|
||||||
|
template <typename T>
|
||||||
|
T simd_reduce(T val) {
|
||||||
|
return simd_product(val);
|
||||||
|
};
|
||||||
|
|
||||||
|
static constexpr constant U init = U(1);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
|
||||||
|
mlx_atomic_fetch_mul_explicit(out, val, offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Operator
|
||||||
|
U operator()(U a, U b) {
|
||||||
|
return a * b;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
struct Min {
|
||||||
|
template <typename T>
|
||||||
|
T simd_reduce(T val) {
|
||||||
|
return simd_min(val);
|
||||||
|
};
|
||||||
|
|
||||||
|
static constexpr constant U init = Limits<U>::max;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
|
||||||
|
mlx_atomic_fetch_min_explicit(out, val, offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Operator
|
||||||
|
U operator()(U a, U b) {
|
||||||
|
return a < b ? a : b;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
struct Max {
|
||||||
|
template <typename T>
|
||||||
|
T simd_reduce(T val) {
|
||||||
|
return simd_max(val);
|
||||||
|
};
|
||||||
|
|
||||||
|
static constexpr constant U init = Limits<U>::min;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
|
||||||
|
mlx_atomic_fetch_max_explicit(out, val, offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Operator
|
||||||
|
U operator()(U a, U b) {
|
||||||
|
return a > b ? a : b;
|
||||||
|
}
|
||||||
|
};
|
492
mlx/backend/metal/kernels/scan.metal
Normal file
492
mlx/backend/metal/kernels/scan.metal
Normal file
@ -0,0 +1,492 @@
|
|||||||
|
#include <metal_math>
|
||||||
|
#include <metal_simdgroup>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/defines.h"
|
||||||
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
struct CumSum {
|
||||||
|
static constexpr constant U init = static_cast<U>(0);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
U operator()(U a, T b) {
|
||||||
|
return a + b;
|
||||||
|
}
|
||||||
|
|
||||||
|
U simd_scan(U x) {
|
||||||
|
return simd_prefix_inclusive_sum(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
U simd_exclusive_scan(U x) {
|
||||||
|
return simd_prefix_exclusive_sum(x);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
struct CumProd {
|
||||||
|
static constexpr constant U init = static_cast<U>(1.0f);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
U operator()(U a, T b) {
|
||||||
|
return a * b;
|
||||||
|
}
|
||||||
|
|
||||||
|
U simd_scan(U x) {
|
||||||
|
return simd_prefix_inclusive_product(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
U simd_exclusive_scan(U x) {
|
||||||
|
return simd_prefix_exclusive_product(x);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct CumProd<bool> {
|
||||||
|
static constexpr constant bool init = true;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
bool operator()(bool a, T b) {
|
||||||
|
return a & static_cast<bool>(b);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool simd_scan(bool x) {
|
||||||
|
for (int i=1; i<=16; i*=2) {
|
||||||
|
bool other = simd_shuffle_up(x, i);
|
||||||
|
x &= other;
|
||||||
|
}
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool simd_exclusive_scan(bool x) {
|
||||||
|
x = simd_scan(x);
|
||||||
|
return simd_shuffle_and_fill_up(x, init, 1);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
struct CumMax {
|
||||||
|
static constexpr constant U init = Limits<U>::min;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
U operator()(U a, T b) {
|
||||||
|
return (a >= b) ? a : b;
|
||||||
|
}
|
||||||
|
|
||||||
|
U simd_scan(U x) {
|
||||||
|
for (int i=1; i<=16; i*=2) {
|
||||||
|
U other = simd_shuffle_up(x, i);
|
||||||
|
x = (x >= other) ? x : other;
|
||||||
|
}
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
U simd_exclusive_scan(U x) {
|
||||||
|
x = simd_scan(x);
|
||||||
|
return simd_shuffle_and_fill_up(x, init, 1);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
struct CumMin {
|
||||||
|
static constexpr constant U init = Limits<U>::max;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
U operator()(U a, T b) {
|
||||||
|
return (a <= b) ? a : b;
|
||||||
|
}
|
||||||
|
|
||||||
|
U simd_scan(U x) {
|
||||||
|
for (int i=1; i<=16; i*=2) {
|
||||||
|
U other = simd_shuffle_up(x, i);
|
||||||
|
x = (x <= other) ? x : other;
|
||||||
|
}
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
U simd_exclusive_scan(U x) {
|
||||||
|
x = simd_scan(x);
|
||||||
|
return simd_shuffle_and_fill_up(x, init, 1);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename U, int N_READS, bool reverse>
|
||||||
|
inline void load_unsafe(U values[N_READS], const device T * input) {
|
||||||
|
if (reverse) {
|
||||||
|
for (int i=0; i<N_READS; i++) {
|
||||||
|
values[N_READS-i-1] = input[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i=0; i<N_READS; i++) {
|
||||||
|
values[i] = input[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, int N_READS, bool reverse>
|
||||||
|
inline void load_safe(U values[N_READS], const device T * input, int start, int total, U init) {
|
||||||
|
if (reverse) {
|
||||||
|
for (int i=0; i<N_READS; i++) {
|
||||||
|
values[N_READS-i-1] = (start + N_READS - i - 1 < total) ? input[i] : init;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i=0; i<N_READS; i++) {
|
||||||
|
values[i] = (start + i < total) ? input[i] : init;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename U, int N_READS, bool reverse>
|
||||||
|
inline void write_unsafe(U values[N_READS], device U * out) {
|
||||||
|
if (reverse) {
|
||||||
|
for (int i=0; i<N_READS; i++) {
|
||||||
|
out[i] = values[N_READS-i-1];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i=0; i<N_READS; i++) {
|
||||||
|
out[i] = values[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename U, int N_READS, bool reverse>
|
||||||
|
inline void write_safe(U values[N_READS], device U * out, int start, int total) {
|
||||||
|
if (reverse) {
|
||||||
|
for (int i=0; i<N_READS; i++) {
|
||||||
|
if (start + N_READS - i - 1 < total) {
|
||||||
|
out[i] = values[N_READS-i-1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i=0; i<N_READS; i++) {
|
||||||
|
if (start + i < total) {
|
||||||
|
out[i] = values[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool reverse>
|
||||||
|
[[kernel]] void contiguous_scan(
|
||||||
|
const device T* in [[buffer(0)]],
|
||||||
|
device U* out [[buffer(1)]],
|
||||||
|
const constant size_t & axis_size [[buffer(2)]],
|
||||||
|
uint gid [[thread_position_in_grid]],
|
||||||
|
uint lid [[thread_position_in_threadgroup]],
|
||||||
|
uint lsize [[threads_per_threadgroup]],
|
||||||
|
uint simd_size [[threads_per_simdgroup]],
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||||
|
Op op;
|
||||||
|
|
||||||
|
// Position the pointers
|
||||||
|
in += (gid / lsize) * axis_size;
|
||||||
|
out += (gid / lsize) * axis_size;
|
||||||
|
|
||||||
|
// Compute the number of simd_groups
|
||||||
|
uint simd_groups = lsize / simd_size;
|
||||||
|
|
||||||
|
// Allocate memory
|
||||||
|
U prefix = Op::init;
|
||||||
|
U values[N_READS];
|
||||||
|
threadgroup U simdgroup_sums[32];
|
||||||
|
|
||||||
|
// Loop over the reduced axis in blocks of size ceildiv(axis_size, N_READS*lsize)
|
||||||
|
// Read block
|
||||||
|
// Compute inclusive scan of the block
|
||||||
|
// Compute inclusive scan per thread
|
||||||
|
// Compute exclusive scan of thread sums in simdgroup
|
||||||
|
// Write simdgroup sums in SM
|
||||||
|
// Compute exclusive scan of simdgroup sums
|
||||||
|
// Compute the output by scanning prefix, prev_simdgroup, prev_thread, value
|
||||||
|
// Write block
|
||||||
|
|
||||||
|
for (uint r = 0; r < ceildiv(axis_size, N_READS*lsize); r++) {
|
||||||
|
// Compute the block offset
|
||||||
|
uint offset = r*lsize*N_READS + lid*N_READS;
|
||||||
|
|
||||||
|
// Read the values
|
||||||
|
if (reverse) {
|
||||||
|
if ((offset + N_READS) < axis_size) {
|
||||||
|
load_unsafe<T, U, N_READS, reverse>(values, in + axis_size - offset - N_READS);
|
||||||
|
} else {
|
||||||
|
load_safe<T, U, N_READS, reverse>(values, in + axis_size - offset - N_READS, offset, axis_size, Op::init);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if ((offset + N_READS) < axis_size) {
|
||||||
|
load_unsafe<T, U, N_READS, reverse>(values, in + offset);
|
||||||
|
} else {
|
||||||
|
load_safe<T, U, N_READS, reverse>(values, in + offset, offset, axis_size, Op::init);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute an inclusive scan per thread
|
||||||
|
for (int i=1; i<N_READS; i++) {
|
||||||
|
values[i] = op(values[i], values[i-1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute exclusive scan of thread sums
|
||||||
|
U prev_thread = op.simd_exclusive_scan(values[N_READS-1]);
|
||||||
|
|
||||||
|
// Write simdgroup_sums to SM
|
||||||
|
if (simd_lane_id == simd_size - 1) {
|
||||||
|
simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]);
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Compute exclusive scan of simdgroup_sums
|
||||||
|
if (simd_group_id == 0) {
|
||||||
|
U prev_simdgroup = op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]);
|
||||||
|
simdgroup_sums[simd_lane_id] = prev_simdgroup;
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Compute the output
|
||||||
|
for (int i=0; i<N_READS; i++) {
|
||||||
|
values[i] = op(values[i], prefix);
|
||||||
|
values[i] = op(values[i], simdgroup_sums[simd_group_id]);
|
||||||
|
values[i] = op(values[i], prev_thread);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the values
|
||||||
|
if (reverse) {
|
||||||
|
if (inclusive) {
|
||||||
|
if ((offset + N_READS) < axis_size) {
|
||||||
|
write_unsafe<U, N_READS, reverse>(values, out + axis_size - offset - N_READS);
|
||||||
|
} else {
|
||||||
|
write_safe<U, N_READS, reverse>(values, out + axis_size - offset - N_READS, offset, axis_size);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (lid == 0 && offset == 0) {
|
||||||
|
out[axis_size-1] = Op::init;
|
||||||
|
}
|
||||||
|
if ((offset + N_READS + 1) < axis_size) {
|
||||||
|
write_unsafe<U, N_READS, reverse>(values, out + axis_size - offset - 1 - N_READS);
|
||||||
|
} else {
|
||||||
|
write_safe<U, N_READS, reverse>(values, out + axis_size - offset - 1 - N_READS, offset + 1, axis_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (inclusive) {
|
||||||
|
if ((offset + N_READS) < axis_size) {
|
||||||
|
write_unsafe<U, N_READS, reverse>(values, out + offset);
|
||||||
|
} else {
|
||||||
|
write_safe<U, N_READS, reverse>(values, out + offset, offset, axis_size);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (lid == 0 && offset == 0) {
|
||||||
|
out[0] = Op::init;
|
||||||
|
}
|
||||||
|
if ((offset + N_READS + 1) < axis_size) {
|
||||||
|
write_unsafe<U, N_READS, reverse>(values, out + offset + 1);
|
||||||
|
} else {
|
||||||
|
write_safe<U, N_READS, reverse>(values, out + offset + 1, offset + 1, axis_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Share the prefix
|
||||||
|
if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) {
|
||||||
|
simdgroup_sums[0] = values[N_READS-1];
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
prefix = simdgroup_sums[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool reverse>
|
||||||
|
[[kernel]] void strided_scan(
|
||||||
|
const device T* in [[buffer(0)]],
|
||||||
|
device U* out [[buffer(1)]],
|
||||||
|
const constant size_t & axis_size [[buffer(2)]],
|
||||||
|
const constant size_t & stride [[buffer(3)]],
|
||||||
|
uint2 gid [[threadgroup_position_in_grid]],
|
||||||
|
uint2 lid [[thread_position_in_threadgroup]],
|
||||||
|
uint2 lsize [[threads_per_threadgroup]],
|
||||||
|
uint simd_size [[threads_per_simdgroup]]) {
|
||||||
|
Op op;
|
||||||
|
|
||||||
|
// Allocate memory
|
||||||
|
threadgroup U read_buffer[N_READS*32*32 + N_READS*32];
|
||||||
|
U values[N_READS];
|
||||||
|
U prefix[N_READS];
|
||||||
|
for (int i=0; i<N_READS; i++) {
|
||||||
|
prefix[i] = Op::init;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute offsets
|
||||||
|
int offset = gid.y * axis_size * stride;
|
||||||
|
int global_index_x = gid.x * lsize.y * N_READS;
|
||||||
|
|
||||||
|
for (uint j=0; j<axis_size; j+=simd_size) {
|
||||||
|
// Calculate the indices for the current thread
|
||||||
|
uint index_y = j + lid.y;
|
||||||
|
uint check_index_y = index_y;
|
||||||
|
uint index_x = global_index_x + lid.x * N_READS;
|
||||||
|
if (reverse) {
|
||||||
|
index_y = axis_size - 1 - index_y;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read in SM
|
||||||
|
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
|
||||||
|
for (int i=0; i<N_READS; i++) {
|
||||||
|
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = in[offset + index_y * stride + index_x + i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i=0; i<N_READS; i++) {
|
||||||
|
if (check_index_y < axis_size && (index_x + i) < stride) {
|
||||||
|
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = in[offset + index_y * stride + index_x + i];
|
||||||
|
} else {
|
||||||
|
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = Op::init;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Read strided into registers
|
||||||
|
for (int i=0; i<N_READS; i++) {
|
||||||
|
values[i] = read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
|
||||||
|
}
|
||||||
|
// Do we need the following barrier? Shouldn't all simd threads execute simultaneously?
|
||||||
|
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Perform the scan
|
||||||
|
for (int i=0; i<N_READS; i++) {
|
||||||
|
values[i] = op.simd_scan(values[i]);
|
||||||
|
values[i] = op(values[i], prefix[i]);
|
||||||
|
prefix[i] = simd_shuffle(values[i], simd_size-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write to SM
|
||||||
|
for (int i=0; i<N_READS; i++) {
|
||||||
|
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] = values[i];
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Write to device memory
|
||||||
|
if (!inclusive) {
|
||||||
|
if (check_index_y == 0) {
|
||||||
|
if ((index_x + N_READS) < stride) {
|
||||||
|
for (int i=0; i<N_READS; i++) {
|
||||||
|
out[offset + index_y * stride + index_x + i] = Op::init;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i=0; i<N_READS; i++) {
|
||||||
|
if ((index_x + i) < stride) {
|
||||||
|
out[offset + index_y * stride + index_x + i] = Op::init;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (reverse) {
|
||||||
|
index_y -= 1;
|
||||||
|
check_index_y += 1;
|
||||||
|
} else {
|
||||||
|
index_y += 1;
|
||||||
|
check_index_y += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
|
||||||
|
for (int i=0; i<N_READS; i++) {
|
||||||
|
out[offset + index_y * stride + index_x + i] = read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i=0; i<N_READS; i++) {
|
||||||
|
if (check_index_y < axis_size && (index_x + i) < stride) {
|
||||||
|
out[offset + index_y * stride + index_x + i] = read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#define instantiate_contiguous_scan(name, itype, otype, op, inclusive, reverse, nreads) \
|
||||||
|
template [[host_name("contiguous_scan_" #name)]] \
|
||||||
|
[[kernel]] void contiguous_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
|
||||||
|
const device itype* in [[buffer(0)]], \
|
||||||
|
device otype* out [[buffer(1)]], \
|
||||||
|
const constant size_t & axis_size [[buffer(2)]], \
|
||||||
|
uint gid [[thread_position_in_grid]], \
|
||||||
|
uint lid [[thread_position_in_threadgroup]], \
|
||||||
|
uint lsize [[threads_per_threadgroup]], \
|
||||||
|
uint simd_size [[threads_per_simdgroup]], \
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||||
|
|
||||||
|
#define instantiate_strided_scan(name, itype, otype, op, inclusive, reverse, nreads) \
|
||||||
|
template [[host_name("strided_scan_" #name)]] \
|
||||||
|
[[kernel]] void strided_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
|
||||||
|
const device itype* in [[buffer(0)]], \
|
||||||
|
device otype* out [[buffer(1)]], \
|
||||||
|
const constant size_t & axis_size [[buffer(2)]], \
|
||||||
|
const constant size_t & stride [[buffer(3)]], \
|
||||||
|
uint2 gid [[thread_position_in_grid]], \
|
||||||
|
uint2 lid [[thread_position_in_threadgroup]], \
|
||||||
|
uint2 lsize [[threads_per_threadgroup]], \
|
||||||
|
uint simd_size [[threads_per_simdgroup]]);
|
||||||
|
|
||||||
|
|
||||||
|
#define instantiate_scan_helper(name, itype, otype, op, nreads) \
|
||||||
|
instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
|
||||||
|
instantiate_contiguous_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
|
||||||
|
instantiate_contiguous_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
|
||||||
|
instantiate_contiguous_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) \
|
||||||
|
instantiate_strided_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
|
||||||
|
instantiate_strided_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
|
||||||
|
instantiate_strided_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
|
||||||
|
instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads)
|
||||||
|
|
||||||
|
instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSum, 4)
|
||||||
|
instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4)
|
||||||
|
instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4)
|
||||||
|
instantiate_scan_helper(sum_uint32_uint32, uint32_t, uint32_t, CumSum, 4)
|
||||||
|
//instantiate_scan_helper(sum_uint64_uint64, uint64_t, uint64_t, CumSum, 2)
|
||||||
|
instantiate_scan_helper(sum_int8_int8, int8_t, int8_t, CumSum, 4)
|
||||||
|
instantiate_scan_helper(sum_int16_int16, int16_t, int16_t, CumSum, 4)
|
||||||
|
instantiate_scan_helper(sum_int32_int32, int32_t, int32_t, CumSum, 4)
|
||||||
|
//instantiate_scan_helper(sum_int64_int64, int64_t, int64_t, CumSum, 2)
|
||||||
|
instantiate_scan_helper(sum_float16_float16, half, half, CumSum, 4)
|
||||||
|
instantiate_scan_helper(sum_float32_float32, float, float, CumSum, 4)
|
||||||
|
//instantiate_scan_helper(sum_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumSum, 4)
|
||||||
|
//instantiate_scan_helper(sum_complex64_complex64, complex64_t, complex64_t, CumSum)
|
||||||
|
//instantiate_scan_helper(prod_bool__bool_, bool, bool, CumProd, 4)
|
||||||
|
instantiate_scan_helper(prod_uint8_uint8, uint8_t, uint8_t, CumProd, 4)
|
||||||
|
instantiate_scan_helper(prod_uint16_uint16, uint16_t, uint16_t, CumProd, 4)
|
||||||
|
instantiate_scan_helper(prod_uint32_uint32, uint32_t, uint32_t, CumProd, 4)
|
||||||
|
//instantiate_scan_helper(prod_uint64_uint64, uint64_t, uint64_t, CumProd, 2)
|
||||||
|
instantiate_scan_helper(prod_int8_int8, int8_t, int8_t, CumProd, 4)
|
||||||
|
instantiate_scan_helper(prod_int16_int16, int16_t, int16_t, CumProd, 4)
|
||||||
|
instantiate_scan_helper(prod_int32_int32, int32_t, int32_t, CumProd, 4)
|
||||||
|
//instantiate_scan_helper(prod_int64_int64, int64_t, int64_t, CumProd, 2)
|
||||||
|
instantiate_scan_helper(prod_float16_float16, half, half, CumProd, 4)
|
||||||
|
instantiate_scan_helper(prod_float32_float32, float, float, CumProd, 4)
|
||||||
|
//instantiate_scan_helper(prod_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumProd, 4)
|
||||||
|
//instantiate_scan_helper(prod_complex64_complex64, complex64_t, complex64_t, CumProd)
|
||||||
|
//instantiate_scan_helper(max_bool__bool_, bool, bool, CumMax, 4)
|
||||||
|
instantiate_scan_helper(max_uint8_uint8, uint8_t, uint8_t, CumMax, 4)
|
||||||
|
instantiate_scan_helper(max_uint16_uint16, uint16_t, uint16_t, CumMax, 4)
|
||||||
|
instantiate_scan_helper(max_uint32_uint32, uint32_t, uint32_t, CumMax, 4)
|
||||||
|
//instantiate_scan_helper(max_uint64_uint64, uint64_t, uint64_t, CumMax, 2)
|
||||||
|
instantiate_scan_helper(max_int8_int8, int8_t, int8_t, CumMax, 4)
|
||||||
|
instantiate_scan_helper(max_int16_int16, int16_t, int16_t, CumMax, 4)
|
||||||
|
instantiate_scan_helper(max_int32_int32, int32_t, int32_t, CumMax, 4)
|
||||||
|
//instantiate_scan_helper(max_int64_int64, int64_t, int64_t, CumMax, 2)
|
||||||
|
instantiate_scan_helper(max_float16_float16, half, half, CumMax, 4)
|
||||||
|
instantiate_scan_helper(max_float32_float32, float, float, CumMax, 4)
|
||||||
|
//instantiate_scan_helper(max_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMax, 4)
|
||||||
|
//instantiate_scan_helper(max_complex64_complex64, complex64_t, complex64_t, CumMax)
|
||||||
|
//instantiate_scan_helper(min_bool__bool_, bool, bool, CumMin, 4)
|
||||||
|
instantiate_scan_helper(min_uint8_uint8, uint8_t, uint8_t, CumMin, 4)
|
||||||
|
instantiate_scan_helper(min_uint16_uint16, uint16_t, uint16_t, CumMin, 4)
|
||||||
|
instantiate_scan_helper(min_uint32_uint32, uint32_t, uint32_t, CumMin, 4)
|
||||||
|
//instantiate_scan_helper(min_uint64_uint64, uint64_t, uint64_t, CumMin, 2)
|
||||||
|
instantiate_scan_helper(min_int8_int8, int8_t, int8_t, CumMin, 4)
|
||||||
|
instantiate_scan_helper(min_int16_int16, int16_t, int16_t, CumMin, 4)
|
||||||
|
instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMin, 4)
|
||||||
|
//instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMin, 2)
|
||||||
|
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
|
||||||
|
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
|
||||||
|
//instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
|
||||||
|
//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin)
|
88
mlx/backend/metal/metal.cpp
Normal file
88
mlx/backend/metal/metal.cpp
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
#include <cstdlib>
|
||||||
|
#include <future>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/scheduler.h"
|
||||||
|
|
||||||
|
namespace mlx::core::metal {
|
||||||
|
|
||||||
|
int max_ops_per_buffer() {
|
||||||
|
auto get_val = []() {
|
||||||
|
if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) {
|
||||||
|
return atoi(buff_str);
|
||||||
|
} else {
|
||||||
|
return 10;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
static int max_ops_per_buffer_ = get_val();
|
||||||
|
return max_ops_per_buffer_;
|
||||||
|
}
|
||||||
|
|
||||||
|
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
|
||||||
|
|
||||||
|
MTL::CommandBuffer* increment_command_buffer(Stream s) {
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
auto command_buffer = d.get_command_buffer(s.index);
|
||||||
|
if (command_buffer == nullptr ||
|
||||||
|
d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
|
||||||
|
if (command_buffer != nullptr) {
|
||||||
|
d.end_encoding(s.index);
|
||||||
|
scheduler::notify_new_task(s);
|
||||||
|
command_buffer->addCompletedHandler(
|
||||||
|
[s](MTL::CommandBuffer*) { scheduler::notify_task_completion(s); });
|
||||||
|
d.commit_command_buffer(s.index);
|
||||||
|
}
|
||||||
|
command_buffer = d.new_command_buffer(s.index);
|
||||||
|
}
|
||||||
|
d.increment_command_buffer_ops(s.index);
|
||||||
|
return command_buffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::function<void()> make_task(
|
||||||
|
array& arr,
|
||||||
|
std::vector<std::shared_future<void>> deps,
|
||||||
|
std::shared_ptr<std::promise<void>> p,
|
||||||
|
bool retain_graph) {
|
||||||
|
auto task =
|
||||||
|
[retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
|
||||||
|
for (auto& d : deps) {
|
||||||
|
d.wait();
|
||||||
|
}
|
||||||
|
auto s = arr.primitive().stream();
|
||||||
|
auto command_buffer = increment_command_buffer(s);
|
||||||
|
arr.primitive().eval_gpu(arr.inputs(), arr);
|
||||||
|
if (p) {
|
||||||
|
metal::device(s.device).end_encoding(s.index);
|
||||||
|
scheduler::notify_new_task(s);
|
||||||
|
command_buffer->addCompletedHandler(
|
||||||
|
[retain_graph, s, arr, p = std::move(p)](
|
||||||
|
MTL::CommandBuffer*) mutable {
|
||||||
|
if (!retain_graph) {
|
||||||
|
arr.detach();
|
||||||
|
}
|
||||||
|
p->set_value();
|
||||||
|
// Signal this thread to clear the pool on a synchroniztion.
|
||||||
|
scheduler::enqueue(s, []() {
|
||||||
|
thread_autorelease_pool()->release();
|
||||||
|
thread_autorelease_pool() =
|
||||||
|
NS::AutoreleasePool::alloc()->init();
|
||||||
|
});
|
||||||
|
scheduler::notify_task_completion(s);
|
||||||
|
});
|
||||||
|
metal::device(s.device).commit_command_buffer(s.index);
|
||||||
|
} else {
|
||||||
|
command_buffer->addCompletedHandler(
|
||||||
|
[retain_graph, s, arr](MTL::CommandBuffer*) mutable {
|
||||||
|
if (!retain_graph) {
|
||||||
|
arr.detach();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
return task;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::metal
|
28
mlx/backend/metal/metal.h
Normal file
28
mlx/backend/metal/metal.h
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <future>
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/stream.h"
|
||||||
|
|
||||||
|
namespace mlx::core::metal {
|
||||||
|
|
||||||
|
constexpr bool is_available() {
|
||||||
|
#ifdef _METAL_
|
||||||
|
return true;
|
||||||
|
#else
|
||||||
|
return false;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
void new_stream(Stream stream);
|
||||||
|
|
||||||
|
std::function<void()> make_task(
|
||||||
|
array& arr,
|
||||||
|
std::vector<std::shared_future<void>> deps,
|
||||||
|
std::shared_ptr<std::promise<void>> p,
|
||||||
|
bool retain_graph);
|
||||||
|
|
||||||
|
} // namespace mlx::core::metal
|
82
mlx/backend/metal/softmax.cpp
Normal file
82
mlx/backend/metal/softmax.cpp
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/copy.h"
|
||||||
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/kernels/defines.h"
|
||||||
|
#include "mlx/backend/metal/utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
if (!is_floating_point(out.dtype())) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[softmax] Does not support non-floating point types.");
|
||||||
|
}
|
||||||
|
auto& s = stream();
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
|
||||||
|
// Make sure that the last dimension is contiguous
|
||||||
|
std::vector<array> copies;
|
||||||
|
auto check_input = [&copies, &s](const array& x) {
|
||||||
|
if (x.strides()[x.ndim() - 1] == 1) {
|
||||||
|
return x;
|
||||||
|
} else {
|
||||||
|
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||||
|
copy_gpu(x, x_copy, CopyType::General, s);
|
||||||
|
copies.push_back(x_copy);
|
||||||
|
return x_copy;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
const array& in = check_input(inputs[0]);
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
|
||||||
|
in.data_size(),
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
|
||||||
|
int axis_size = in.shape().back();
|
||||||
|
int n_rows = in.data_size() / axis_size;
|
||||||
|
|
||||||
|
const int simd_size = 32;
|
||||||
|
const int n_reads = SOFTMAX_N_READS;
|
||||||
|
const int looped_limit = SOFTMAX_LOOPED_LIMIT;
|
||||||
|
std::string op_name = "softmax_";
|
||||||
|
if (axis_size > looped_limit) {
|
||||||
|
op_name += "looped_";
|
||||||
|
}
|
||||||
|
op_name += type_to_name(out);
|
||||||
|
auto compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
{
|
||||||
|
auto kernel = d.get_kernel(op_name);
|
||||||
|
|
||||||
|
MTL::Size grid_dims, group_dims;
|
||||||
|
if (axis_size <= looped_limit) {
|
||||||
|
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
|
||||||
|
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
|
||||||
|
size_t threadgroup_size = simd_size * simds_needed;
|
||||||
|
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
|
||||||
|
size_t n_threads = n_rows * threadgroup_size;
|
||||||
|
grid_dims = MTL::Size(n_threads, 1, 1);
|
||||||
|
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||||
|
} else {
|
||||||
|
size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
|
size_t n_threads = n_rows * threadgroup_size;
|
||||||
|
grid_dims = MTL::Size(n_threads, 1, 1);
|
||||||
|
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
set_array_buffer(compute_encoder, in, 0);
|
||||||
|
set_array_buffer(compute_encoder, out, 1);
|
||||||
|
compute_encoder->setBytes(&axis_size, sizeof(int), 2);
|
||||||
|
compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 0);
|
||||||
|
compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 1);
|
||||||
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||||
|
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
336
mlx/backend/metal/sort.cpp
Normal file
336
mlx/backend/metal/sort.cpp
Normal file
@ -0,0 +1,336 @@
|
|||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/copy.h"
|
||||||
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <bool ARGSORT>
|
||||||
|
void single_block_sort(
|
||||||
|
const Stream& s,
|
||||||
|
metal::Device& d,
|
||||||
|
const array& in,
|
||||||
|
array& out,
|
||||||
|
int axis,
|
||||||
|
int bn,
|
||||||
|
int tn) {
|
||||||
|
// Prepare shapes
|
||||||
|
int n_rows = in.size() / in.shape(axis);
|
||||||
|
|
||||||
|
std::vector<size_t> nc_str = in.strides();
|
||||||
|
nc_str.erase(nc_str.begin() + axis);
|
||||||
|
|
||||||
|
std::vector<int> nc_shape = in.shape();
|
||||||
|
nc_shape.erase(nc_shape.begin() + axis);
|
||||||
|
|
||||||
|
int nc_dim = nc_shape.size();
|
||||||
|
|
||||||
|
int size_sorted_axis = in.shape(axis);
|
||||||
|
int stride_sorted_axis = in.strides()[axis];
|
||||||
|
int stride_segment_axis = *std::min_element(nc_str.begin(), nc_str.end());
|
||||||
|
|
||||||
|
// Check if remaining strides are contiguous
|
||||||
|
bool contiguous_write = true;
|
||||||
|
if (axis != in.ndim() - 1 && axis != 0) {
|
||||||
|
for (int i = 0; i < nc_str.size() - 1; ++i) {
|
||||||
|
size_t expected = nc_str[i + 1] * nc_str[i + 1];
|
||||||
|
contiguous_write &= (nc_str[i] == expected);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare kernel name
|
||||||
|
std::ostringstream kname;
|
||||||
|
if (ARGSORT) {
|
||||||
|
kname << "arg_";
|
||||||
|
}
|
||||||
|
kname << "block_merge_sort_" << type_to_name(in) << "_" << type_to_name(out)
|
||||||
|
<< "_bn" << bn << "_tn" << tn;
|
||||||
|
|
||||||
|
if (!contiguous_write) {
|
||||||
|
kname << "_nc";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare command encoder
|
||||||
|
auto compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
auto kernel = d.get_kernel(kname.str());
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
|
// Set inputs
|
||||||
|
set_array_buffer(compute_encoder, in, 0);
|
||||||
|
set_array_buffer(compute_encoder, out, 1);
|
||||||
|
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2);
|
||||||
|
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 3);
|
||||||
|
|
||||||
|
if (contiguous_write) {
|
||||||
|
compute_encoder->setBytes(&stride_segment_axis, sizeof(int), 4);
|
||||||
|
} else {
|
||||||
|
compute_encoder->setBytes(&nc_dim, sizeof(int), 4);
|
||||||
|
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 5);
|
||||||
|
compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 6);
|
||||||
|
}
|
||||||
|
|
||||||
|
MTL::Size group_dims = MTL::Size(bn, 1, 1);
|
||||||
|
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
|
||||||
|
|
||||||
|
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <bool ARGSORT>
|
||||||
|
void multi_block_sort(
|
||||||
|
const Stream& s,
|
||||||
|
metal::Device& d,
|
||||||
|
const array& in,
|
||||||
|
array& out,
|
||||||
|
int axis,
|
||||||
|
int bn,
|
||||||
|
int tn,
|
||||||
|
int n_blocks) {
|
||||||
|
// Prepare shapes
|
||||||
|
int n_rows = in.size() / in.shape(axis);
|
||||||
|
|
||||||
|
std::vector<size_t> nc_str = in.strides();
|
||||||
|
nc_str.erase(nc_str.begin() + axis);
|
||||||
|
|
||||||
|
std::vector<int> nc_shape = in.shape();
|
||||||
|
nc_shape.erase(nc_shape.begin() + axis);
|
||||||
|
|
||||||
|
int nc_dim = nc_shape.size();
|
||||||
|
|
||||||
|
int size_sorted_axis = in.shape(axis);
|
||||||
|
int stride_sorted_axis = in.strides()[axis];
|
||||||
|
|
||||||
|
// Make temporary copies
|
||||||
|
array dev_vals_0({n_rows, size_sorted_axis}, in.dtype(), nullptr, {});
|
||||||
|
array dev_vals_1({n_rows, size_sorted_axis}, in.dtype(), nullptr, {});
|
||||||
|
|
||||||
|
array dev_idxs_0({n_rows, size_sorted_axis}, uint32, nullptr, {});
|
||||||
|
array dev_idxs_1({n_rows, size_sorted_axis}, uint32, nullptr, {});
|
||||||
|
|
||||||
|
array block_partitions({n_rows, n_blocks + 1}, uint32, nullptr, {});
|
||||||
|
|
||||||
|
// Do allocations
|
||||||
|
dev_vals_0.set_data(allocator::malloc_or_wait(dev_vals_0.nbytes()));
|
||||||
|
dev_vals_1.set_data(allocator::malloc_or_wait(dev_vals_1.nbytes()));
|
||||||
|
dev_idxs_0.set_data(allocator::malloc_or_wait(dev_idxs_0.nbytes()));
|
||||||
|
dev_idxs_1.set_data(allocator::malloc_or_wait(dev_idxs_1.nbytes()));
|
||||||
|
block_partitions.set_data(
|
||||||
|
allocator::malloc_or_wait(block_partitions.nbytes()));
|
||||||
|
|
||||||
|
std::vector<array> copies = {
|
||||||
|
dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions};
|
||||||
|
|
||||||
|
// Prepare command encoder
|
||||||
|
auto compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
|
||||||
|
// Do blockwise sort
|
||||||
|
{
|
||||||
|
std::ostringstream kname;
|
||||||
|
kname << "mb_block_sort_" << type_to_name(dev_vals_0) << "_"
|
||||||
|
<< type_to_name(dev_idxs_0) << "_bn" << bn << "_tn" << tn;
|
||||||
|
|
||||||
|
auto kernel = d.get_kernel(kname.str());
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
|
set_array_buffer(compute_encoder, in, 0);
|
||||||
|
set_array_buffer(compute_encoder, dev_vals_0, 1);
|
||||||
|
set_array_buffer(compute_encoder, dev_idxs_0, 2);
|
||||||
|
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
|
||||||
|
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 4);
|
||||||
|
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
|
||||||
|
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6);
|
||||||
|
compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 7);
|
||||||
|
|
||||||
|
MTL::Size group_dims = MTL::Size(bn, 1, 1);
|
||||||
|
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
|
||||||
|
|
||||||
|
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do merges
|
||||||
|
bool ping = false;
|
||||||
|
array dev_vals_in = dev_vals_0;
|
||||||
|
array dev_idxs_in = dev_idxs_0;
|
||||||
|
array dev_vals_out = dev_vals_1;
|
||||||
|
array dev_idxs_out = dev_idxs_1;
|
||||||
|
for (int merge_tiles = 2; merge_tiles <= n_blocks; merge_tiles *= 2) {
|
||||||
|
dev_vals_in = ping ? dev_vals_1 : dev_vals_0;
|
||||||
|
dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0;
|
||||||
|
dev_vals_out = ping ? dev_vals_0 : dev_vals_1;
|
||||||
|
dev_idxs_out = ping ? dev_idxs_0 : dev_idxs_1;
|
||||||
|
ping = !ping;
|
||||||
|
|
||||||
|
// Do partiton
|
||||||
|
{
|
||||||
|
std::ostringstream kname;
|
||||||
|
kname << "mb_block_partiton_" << type_to_name(dev_vals_in) << "_"
|
||||||
|
<< type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;
|
||||||
|
|
||||||
|
auto kernel = d.get_kernel(kname.str());
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
|
set_array_buffer(compute_encoder, block_partitions, 0);
|
||||||
|
set_array_buffer(compute_encoder, dev_vals_in, 1);
|
||||||
|
set_array_buffer(compute_encoder, dev_idxs_in, 2);
|
||||||
|
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
|
||||||
|
compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);
|
||||||
|
|
||||||
|
MTL::Size group_dims = MTL::Size(n_blocks + 1, 1, 1);
|
||||||
|
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
|
||||||
|
|
||||||
|
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do merge
|
||||||
|
{
|
||||||
|
std::ostringstream kname;
|
||||||
|
kname << "mb_block_merge_" << type_to_name(dev_vals_in) << "_"
|
||||||
|
<< type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;
|
||||||
|
|
||||||
|
auto kernel = d.get_kernel(kname.str());
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
|
set_array_buffer(compute_encoder, block_partitions, 0);
|
||||||
|
set_array_buffer(compute_encoder, dev_vals_in, 1);
|
||||||
|
set_array_buffer(compute_encoder, dev_idxs_in, 2);
|
||||||
|
set_array_buffer(compute_encoder, dev_vals_out, 3);
|
||||||
|
set_array_buffer(compute_encoder, dev_idxs_out, 4);
|
||||||
|
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 5);
|
||||||
|
compute_encoder->setBytes(&merge_tiles, sizeof(int), 6);
|
||||||
|
compute_encoder->setBytes(&n_blocks, sizeof(int), 7);
|
||||||
|
|
||||||
|
MTL::Size group_dims = MTL::Size(bn, 1, 1);
|
||||||
|
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
|
||||||
|
|
||||||
|
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy outputs with appropriate strides
|
||||||
|
array strided_out_arr = ARGSORT ? dev_idxs_out : dev_vals_out;
|
||||||
|
|
||||||
|
if (axis == strided_out_arr.ndim() - 1) {
|
||||||
|
copy_gpu_inplace(strided_out_arr, out, CopyType::Vector, s);
|
||||||
|
} else {
|
||||||
|
std::vector<int> strided_out_shape = strided_out_arr.shape();
|
||||||
|
std::vector<size_t> strided_out_str = strided_out_arr.strides();
|
||||||
|
|
||||||
|
int out_axis_shape = strided_out_shape[axis];
|
||||||
|
int out_axis_str = strided_out_str[axis];
|
||||||
|
|
||||||
|
strided_out_shape.erase(strided_out_shape.begin() + axis);
|
||||||
|
strided_out_str.erase(strided_out_str.begin() + axis);
|
||||||
|
|
||||||
|
strided_out_shape.push_back(out_axis_shape);
|
||||||
|
strided_out_str.push_back(out_axis_str);
|
||||||
|
|
||||||
|
array strided_out_slice(strided_out_shape, out.dtype(), nullptr, {});
|
||||||
|
strided_out_slice.copy_shared_buffer(
|
||||||
|
strided_out_arr,
|
||||||
|
strided_out_str,
|
||||||
|
strided_out_arr.flags(),
|
||||||
|
strided_out_arr.size(),
|
||||||
|
0);
|
||||||
|
|
||||||
|
copy_gpu_inplace(strided_out_slice, out, CopyType::General, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear copies
|
||||||
|
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||||
|
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||||
|
}
|
||||||
|
|
||||||
|
template <bool ARGSORT>
|
||||||
|
void gpu_merge_sort(
|
||||||
|
const Stream& s,
|
||||||
|
metal::Device& d,
|
||||||
|
const array& in,
|
||||||
|
array& out,
|
||||||
|
int axis_) {
|
||||||
|
// Get size info
|
||||||
|
int axis = axis_ < 0 ? axis_ + in.ndim() : axis_;
|
||||||
|
int size_sorted_axis = in.shape(axis);
|
||||||
|
|
||||||
|
// Get kernel size
|
||||||
|
int tn = 8;
|
||||||
|
int bn = 128;
|
||||||
|
int potential_bn = (size_sorted_axis + tn - 1) / tn;
|
||||||
|
|
||||||
|
if (potential_bn > 256) {
|
||||||
|
bn = 512;
|
||||||
|
} else if (potential_bn > 128) {
|
||||||
|
bn = 256;
|
||||||
|
} else {
|
||||||
|
bn = 128;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (bn == 512 && size_of(in.dtype()) > 4) {
|
||||||
|
bn = 256;
|
||||||
|
}
|
||||||
|
|
||||||
|
int n_per_block = bn * tn;
|
||||||
|
int n_blocks = (size_sorted_axis + n_per_block - 1) / n_per_block;
|
||||||
|
|
||||||
|
if (n_blocks > 1) {
|
||||||
|
return multi_block_sort<ARGSORT>(s, d, in, out, axis, bn, tn, n_blocks);
|
||||||
|
} else {
|
||||||
|
return single_block_sort<ARGSORT>(s, d, in, out, axis, bn, tn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
auto& s = stream();
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
|
||||||
|
gpu_merge_sort<true>(s, d, in, out, axis_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
auto& s = stream();
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
|
||||||
|
gpu_merge_sort<false>(s, d, in, out, axis_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
// We direct arg partition to sort for now
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
auto& s = stream();
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
|
||||||
|
gpu_merge_sort<true>(s, d, in, out, axis_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
// We direct partition to sort for now
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
auto& s = stream();
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
|
||||||
|
gpu_merge_sort<false>(s, d, in, out, axis_);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
7
mlx/backend/no_metal/CMakeLists.txt
Normal file
7
mlx/backend/no_metal/CMakeLists.txt
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
target_sources(
|
||||||
|
mlx
|
||||||
|
PRIVATE
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||||
|
)
|
77
mlx/backend/no_metal/primitives.cpp
Normal file
77
mlx/backend/no_metal/primitives.cpp
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#define NO_GPU(func) \
|
||||||
|
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||||
|
throw std::runtime_error(#func " has no GPU implementation."); \
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
NO_GPU(Abs)
|
||||||
|
NO_GPU(Add)
|
||||||
|
NO_GPU(Arange)
|
||||||
|
NO_GPU(ArcCos)
|
||||||
|
NO_GPU(ArcCosh)
|
||||||
|
NO_GPU(ArcSin)
|
||||||
|
NO_GPU(ArcSinh)
|
||||||
|
NO_GPU(ArcTan)
|
||||||
|
NO_GPU(ArcTanh)
|
||||||
|
NO_GPU(ArgPartition)
|
||||||
|
NO_GPU(ArgReduce)
|
||||||
|
NO_GPU(ArgSort)
|
||||||
|
NO_GPU(AsType)
|
||||||
|
NO_GPU(AsStrided)
|
||||||
|
NO_GPU(Broadcast)
|
||||||
|
NO_GPU(Concatenate)
|
||||||
|
NO_GPU(Convolution)
|
||||||
|
NO_GPU(Copy)
|
||||||
|
NO_GPU(Cos)
|
||||||
|
NO_GPU(Cosh)
|
||||||
|
NO_GPU(Divide)
|
||||||
|
NO_GPU(Equal)
|
||||||
|
NO_GPU(Erf)
|
||||||
|
NO_GPU(ErfInv)
|
||||||
|
NO_GPU(Exp)
|
||||||
|
NO_GPU(FFT)
|
||||||
|
NO_GPU(Full)
|
||||||
|
NO_GPU(Gather)
|
||||||
|
NO_GPU(Greater)
|
||||||
|
NO_GPU(GreaterEqual)
|
||||||
|
NO_GPU(Less)
|
||||||
|
NO_GPU(LessEqual)
|
||||||
|
NO_GPU(Load)
|
||||||
|
NO_GPU(Log)
|
||||||
|
NO_GPU(Log1p)
|
||||||
|
NO_GPU(LogicalNot)
|
||||||
|
NO_GPU(LogAddExp)
|
||||||
|
NO_GPU(Matmul)
|
||||||
|
NO_GPU(Maximum)
|
||||||
|
NO_GPU(Minimum)
|
||||||
|
NO_GPU(Multiply)
|
||||||
|
NO_GPU(Negative)
|
||||||
|
NO_GPU(NotEqual)
|
||||||
|
NO_GPU(Pad)
|
||||||
|
NO_GPU(Partition)
|
||||||
|
NO_GPU(Power)
|
||||||
|
NO_GPU(RandomBits)
|
||||||
|
NO_GPU(Reduce)
|
||||||
|
NO_GPU(Reshape)
|
||||||
|
NO_GPU(Scan)
|
||||||
|
NO_GPU(Scatter)
|
||||||
|
NO_GPU(Sigmoid)
|
||||||
|
NO_GPU(Sign)
|
||||||
|
NO_GPU(Sin)
|
||||||
|
NO_GPU(Sinh)
|
||||||
|
NO_GPU(Slice)
|
||||||
|
NO_GPU(Softmax)
|
||||||
|
NO_GPU(Sort)
|
||||||
|
NO_GPU(Square)
|
||||||
|
NO_GPU(Sqrt)
|
||||||
|
NO_GPU(StopGradient)
|
||||||
|
NO_GPU(Subtract)
|
||||||
|
NO_GPU(Tan)
|
||||||
|
NO_GPU(Tanh)
|
||||||
|
NO_GPU(Transpose)
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
29
mlx/device.cpp
Normal file
29
mlx/device.cpp
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
#include "mlx/device.h"
|
||||||
|
#include "mlx/backend/metal/metal.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
static Device default_device_{
|
||||||
|
metal::is_available() ? Device::gpu : Device::cpu};
|
||||||
|
|
||||||
|
const Device& default_device() {
|
||||||
|
return default_device_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_default_device(const Device& d) {
|
||||||
|
if (!metal::is_available() && d == Device::gpu) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[set_default_device] Cannot set gpu device without gpu backend.");
|
||||||
|
}
|
||||||
|
default_device_ = d;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator==(const Device& lhs, const Device& rhs) {
|
||||||
|
return lhs.type == rhs.type && lhs.index == rhs.index;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator!=(const Device& lhs, const Device& rhs) {
|
||||||
|
return !(lhs == rhs);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
99
mlx/dtype.h
Normal file
99
mlx/dtype.h
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <complex>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <ostream>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "mlx/types/complex.h"
|
||||||
|
#include "mlx/types/half_types.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
struct Dtype {
|
||||||
|
enum class Val {
|
||||||
|
bool_,
|
||||||
|
uint8,
|
||||||
|
uint16,
|
||||||
|
uint32,
|
||||||
|
uint64,
|
||||||
|
int8,
|
||||||
|
int16,
|
||||||
|
int32,
|
||||||
|
int64,
|
||||||
|
float16,
|
||||||
|
float32,
|
||||||
|
bfloat16,
|
||||||
|
complex64,
|
||||||
|
};
|
||||||
|
|
||||||
|
enum class Kind {
|
||||||
|
b, /* bool */
|
||||||
|
u, /* unsigned int */
|
||||||
|
i, /* signed int */
|
||||||
|
f, /* float */
|
||||||
|
c, /* complex */
|
||||||
|
V, /* void - used for brain float */
|
||||||
|
};
|
||||||
|
|
||||||
|
Val val;
|
||||||
|
const uint8_t size;
|
||||||
|
constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size){};
|
||||||
|
constexpr operator Val() const {
|
||||||
|
return val;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
inline bool is_available(const Dtype& dtype) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};
|
||||||
|
|
||||||
|
static constexpr Dtype uint8{Dtype::Val::uint8, sizeof(uint8_t)};
|
||||||
|
static constexpr Dtype uint16{Dtype::Val::uint16, sizeof(uint16_t)};
|
||||||
|
static constexpr Dtype uint32{Dtype::Val::uint32, sizeof(uint32_t)};
|
||||||
|
static constexpr Dtype uint64{Dtype::Val::uint64, sizeof(uint64_t)};
|
||||||
|
|
||||||
|
static constexpr Dtype int8{Dtype::Val::int8, sizeof(int8_t)};
|
||||||
|
static constexpr Dtype int16{Dtype::Val::int16, sizeof(int16_t)};
|
||||||
|
static constexpr Dtype int32{Dtype::Val::int32, sizeof(int32_t)};
|
||||||
|
static constexpr Dtype int64{Dtype::Val::int64, sizeof(int64_t)};
|
||||||
|
|
||||||
|
static constexpr Dtype float16{Dtype::Val::float16, sizeof(uint16_t)};
|
||||||
|
static constexpr Dtype float32{Dtype::Val::float32, sizeof(float)};
|
||||||
|
static constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)};
|
||||||
|
static constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};
|
||||||
|
|
||||||
|
Dtype promote_types(const Dtype& t1, const Dtype& t2);
|
||||||
|
|
||||||
|
inline uint8_t size_of(const Dtype& t) {
|
||||||
|
return t.size;
|
||||||
|
}
|
||||||
|
|
||||||
|
Dtype::Kind kindof(const Dtype& t);
|
||||||
|
|
||||||
|
inline bool is_unsigned(const Dtype& t) {
|
||||||
|
return kindof(t) == Dtype::Kind::u || kindof(t) == Dtype::Kind::b;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool is_floating_point(const Dtype& t) {
|
||||||
|
return kindof(t) == Dtype::Kind::f || kindof(t) == Dtype::Kind::V ||
|
||||||
|
kindof(t) == Dtype::Kind::c;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool is_integral(const Dtype& t) {
|
||||||
|
return !(is_floating_point(t));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct TypeToDtype {
|
||||||
|
operator Dtype();
|
||||||
|
};
|
||||||
|
|
||||||
|
// Array protocol typestring for Dtype
|
||||||
|
std::string dtype_to_array_protocol(const Dtype& t);
|
||||||
|
// Dtype from array protocol type string
|
||||||
|
Dtype dtype_from_array_protocol(const std::string& t);
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
190
mlx/fft.cpp
Normal file
190
mlx/fft.cpp
Normal file
@ -0,0 +1,190 @@
|
|||||||
|
#include <numeric>
|
||||||
|
#include <set>
|
||||||
|
|
||||||
|
#include "mlx/fft.h"
|
||||||
|
#include "mlx/ops.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core::fft {
|
||||||
|
|
||||||
|
array fft_impl(
|
||||||
|
const array& a,
|
||||||
|
std::vector<int> n,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
bool real,
|
||||||
|
bool inverse,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
if (a.ndim() < 1) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[fftn] Requires array with at least one dimension.");
|
||||||
|
}
|
||||||
|
if (n.size() != axes.size()) {
|
||||||
|
throw std::invalid_argument("[fftn] Shape and axes have different sizes.");
|
||||||
|
}
|
||||||
|
if (axes.empty()) {
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<size_t> valid_axes;
|
||||||
|
for (int ax : axes) {
|
||||||
|
valid_axes.push_back(ax < 0 ? ax + a.ndim() : ax);
|
||||||
|
}
|
||||||
|
std::set<int> unique_axes(valid_axes.begin(), valid_axes.end());
|
||||||
|
if (unique_axes.size() != axes.size()) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[fftn] Duplicated axis received " << axes;
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
if (*unique_axes.begin() < 0 || *unique_axes.rbegin() >= a.ndim()) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[fftn] Invalid axis received for array with " << a.ndim()
|
||||||
|
<< " dimensions.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
// In the following shape manipulations there are three cases to consdier:
|
||||||
|
// 1. In a complex to complex transform (fftn / ifftn) the output
|
||||||
|
// and input shapes are the same.
|
||||||
|
// 2. In a real to complex transform (rfftn) n specifies the input dims
|
||||||
|
// and the output dims are n[i] / 2 + 1
|
||||||
|
// 3 In a complex to real transform (irfftn) n specifies the output dims
|
||||||
|
// and the input dims are n[i] / 2 + 1
|
||||||
|
|
||||||
|
if (std::any_of(n.begin(), n.end(), [](auto i) { return i <= 0; })) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[fftn] Invalid FFT output size requested " << n;
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> in_shape = a.shape();
|
||||||
|
for (int i = 0; i < valid_axes.size(); ++i) {
|
||||||
|
in_shape[valid_axes[i]] = n[i];
|
||||||
|
}
|
||||||
|
if (real && inverse) {
|
||||||
|
in_shape[valid_axes.back()] = n.back() / 2 + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool any_greater = false;
|
||||||
|
bool any_less = false;
|
||||||
|
for (int i = 0; i < in_shape.size(); ++i) {
|
||||||
|
any_greater |= in_shape[i] > a.shape()[i];
|
||||||
|
any_less |= in_shape[i] < a.shape()[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
auto in = a;
|
||||||
|
if (any_less) {
|
||||||
|
in = slice(in, std::vector<int>(in.ndim(), 0), in_shape, s);
|
||||||
|
}
|
||||||
|
if (any_greater) {
|
||||||
|
// Pad with zeros
|
||||||
|
auto tmp = zeros(in_shape, a.dtype(), s);
|
||||||
|
in = scatter(tmp, std::vector<array>{}, in, std::vector<int>{}, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto out_shape = in_shape;
|
||||||
|
if (real) {
|
||||||
|
auto ax = valid_axes.back();
|
||||||
|
out_shape[ax] = inverse ? n.back() : out_shape[ax] / 2 + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto in_type = real && !inverse ? float32 : complex64;
|
||||||
|
auto out_type = real && inverse ? float32 : complex64;
|
||||||
|
return array(
|
||||||
|
out_shape,
|
||||||
|
out_type,
|
||||||
|
std::make_unique<FFT>(to_stream(s), valid_axes, inverse, real),
|
||||||
|
{astype(in, in_type, s)});
|
||||||
|
}
|
||||||
|
|
||||||
|
array fft_impl(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
bool real,
|
||||||
|
bool inverse,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
std::vector<int> n;
|
||||||
|
for (auto ax : axes) {
|
||||||
|
n.push_back(a.shape(ax));
|
||||||
|
}
|
||||||
|
if (real && inverse) {
|
||||||
|
n.back() = (n.back() - 1) * 2;
|
||||||
|
}
|
||||||
|
return fft_impl(a, n, axes, real, inverse, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array fft_impl(const array& a, bool real, bool inverse, StreamOrDevice s) {
|
||||||
|
std::vector<int> axes(a.ndim());
|
||||||
|
std::iota(axes.begin(), axes.end(), 0);
|
||||||
|
return fft_impl(a, axes, real, inverse, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array fftn(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& n,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
return fft_impl(a, n, axes, false, false, s);
|
||||||
|
}
|
||||||
|
array fftn(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
return fft_impl(a, axes, false, false, s);
|
||||||
|
}
|
||||||
|
array fftn(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
|
return fft_impl(a, false, false, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array ifftn(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& n,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
return fft_impl(a, n, axes, false, true, s);
|
||||||
|
}
|
||||||
|
array ifftn(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
return fft_impl(a, axes, false, true, s);
|
||||||
|
}
|
||||||
|
array ifftn(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
|
return fft_impl(a, false, true, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array rfftn(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& n,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
return fft_impl(a, n, axes, true, false, s);
|
||||||
|
}
|
||||||
|
array rfftn(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
return fft_impl(a, axes, true, false, s);
|
||||||
|
}
|
||||||
|
array rfftn(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
|
return fft_impl(a, true, false, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array irfftn(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& n,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
return fft_impl(a, n, axes, true, true, s);
|
||||||
|
}
|
||||||
|
array irfftn(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
return fft_impl(a, axes, true, true, s);
|
||||||
|
}
|
||||||
|
array irfftn(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
|
return fft_impl(a, true, true, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::fft
|
21
mlx/graph_utils.h
Normal file
21
mlx/graph_utils.h
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/array.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
void print_graph(std::ostream& os, const std::vector<array>& outputs);
|
||||||
|
|
||||||
|
template <typename... Arrays>
|
||||||
|
void print_graph(std::ostream& os, Arrays... outputs) {
|
||||||
|
print_graph(os, std::vector<array>{std::forward<Arrays>(outputs)...});
|
||||||
|
}
|
||||||
|
|
||||||
|
void export_to_dot(std::ostream& os, const std::vector<array>& outputs);
|
||||||
|
|
||||||
|
template <typename... Arrays>
|
||||||
|
void export_to_dot(std::ostream& os, Arrays... outputs) {
|
||||||
|
export_to_dot(os, std::vector<array>{std::forward<Arrays>(outputs)...});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
2323
mlx/ops.cpp
Normal file
2323
mlx/ops.cpp
Normal file
File diff suppressed because it is too large
Load Diff
16
mlx/transforms_impl.h
Normal file
16
mlx/transforms_impl.h
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
|
||||||
|
namespace mlx::core::detail {
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, std::vector<array>> vmap_trace(
|
||||||
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& in_axes);
|
||||||
|
|
||||||
|
std::vector<array> vmap_replace(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<array>& s_inputs,
|
||||||
|
const std::vector<array>& s_outputs,
|
||||||
|
const std::vector<int>& in_axes,
|
||||||
|
const std::vector<int>& out_axes);
|
||||||
|
|
||||||
|
} // namespace mlx::core::detail
|
75
mlx/types/complex.h
Normal file
75
mlx/types/complex.h
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
#pragma once
|
||||||
|
#include <complex>
|
||||||
|
#include "mlx/types/half_types.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
struct complex64_t;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static constexpr bool can_convert_to_complex64 =
|
||||||
|
!std::is_same_v<T, complex64_t> && std::is_convertible_v<T, float>;
|
||||||
|
|
||||||
|
struct complex64_t : public std::complex<float> {
|
||||||
|
complex64_t(float v, float u) : std::complex<float>(v, u){};
|
||||||
|
complex64_t(std::complex<float> v) : std::complex<float>(v){};
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename std::enable_if<can_convert_to_complex64<T>>::type>
|
||||||
|
complex64_t(T x) : std::complex<float>(x){};
|
||||||
|
|
||||||
|
operator float() const {
|
||||||
|
return real();
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
inline bool operator>=(const complex64_t& a, const complex64_t& b) {
|
||||||
|
return (a.real() > b.real()) ||
|
||||||
|
(a.real() == b.real() && a.imag() >= b.imag());
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool operator>(const complex64_t& a, const complex64_t& b) {
|
||||||
|
return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag());
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool operator<=(const complex64_t& a, const complex64_t& b) {
|
||||||
|
return operator>=(b, a);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool operator<(const complex64_t& a, const complex64_t& b) {
|
||||||
|
return operator>(b, a);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline complex64_t operator-(const complex64_t& v) {
|
||||||
|
return -static_cast<std::complex<float>>(v);
|
||||||
|
}
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
#define complex_binop_helper(_op_, _operator_, itype) \
|
||||||
|
inline complex64_t _operator_(itype x, const complex64_t& y) { \
|
||||||
|
return x _op_ static_cast<std::complex<float>>(y); \
|
||||||
|
} \
|
||||||
|
inline complex64_t _operator_(const complex64_t& x, itype y) { \
|
||||||
|
return static_cast<std::complex<float>>(x) _op_ y; \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define complex_binop(_op_, _operator_) \
|
||||||
|
inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \
|
||||||
|
return static_cast<std::complex<float>>(x) \
|
||||||
|
_op_ static_cast<std::complex<float>>(y); \
|
||||||
|
} \
|
||||||
|
complex_binop_helper(_op_, _operator_, bool) \
|
||||||
|
complex_binop_helper(_op_, _operator_, uint32_t) \
|
||||||
|
complex_binop_helper(_op_, _operator_, uint64_t) \
|
||||||
|
complex_binop_helper(_op_, _operator_, int32_t) \
|
||||||
|
complex_binop_helper(_op_, _operator_, int64_t) \
|
||||||
|
complex_binop_helper(_op_, _operator_, float16_t) \
|
||||||
|
complex_binop_helper(_op_, _operator_, bfloat16_t) \
|
||||||
|
complex_binop_helper(_op_, _operator_, const std::complex<float>&) \
|
||||||
|
complex_binop_helper(_op_, _operator_, float)
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
complex_binop(+, operator+)
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
232
mlx/types/fp16.h
Normal file
232
mlx/types/fp16.h
Normal file
@ -0,0 +1,232 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#define __MLX_HALF_NAN__ 0x7D00
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
union float_bits_fp16 {
|
||||||
|
float f;
|
||||||
|
uint32_t u;
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
struct _MLX_Float16 {
|
||||||
|
uint16_t bits_;
|
||||||
|
|
||||||
|
// Default constructor
|
||||||
|
_MLX_Float16() = default;
|
||||||
|
|
||||||
|
// Default copy constructor
|
||||||
|
_MLX_Float16(_MLX_Float16 const&) = default;
|
||||||
|
|
||||||
|
// Appease std::vector<bool> for being special
|
||||||
|
_MLX_Float16& operator=(std::vector<bool>::reference x) {
|
||||||
|
bits_ = x;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
_MLX_Float16& operator=(const float& x) {
|
||||||
|
return (*this = _MLX_Float16(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
// From float32
|
||||||
|
_MLX_Float16(const float& x) : bits_(0) {
|
||||||
|
// Conversion following
|
||||||
|
// https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h
|
||||||
|
|
||||||
|
// Union
|
||||||
|
float_bits_fp16 in;
|
||||||
|
|
||||||
|
// Take fp32 bits
|
||||||
|
in.f = x;
|
||||||
|
|
||||||
|
// Find and take sign bit
|
||||||
|
uint32_t x_sign_32 = in.u & uint32_t(0x80000000);
|
||||||
|
uint16_t x_sign_16 = (x_sign_32 >> 16);
|
||||||
|
|
||||||
|
if (std::isnan(x)) {
|
||||||
|
bits_ = x_sign_16 | uint16_t(__MLX_HALF_NAN__);
|
||||||
|
} else {
|
||||||
|
// Union
|
||||||
|
float_bits_fp16 inf_scale, zero_scale, magic_bits;
|
||||||
|
|
||||||
|
// Find exponent bits and take the max supported by half
|
||||||
|
uint32_t x_expo_32 = in.u & uint32_t(0x7f800000);
|
||||||
|
uint32_t max_expo_32 = uint32_t(0x38800000);
|
||||||
|
x_expo_32 = x_expo_32 < max_expo_32 ? max_expo_32 : x_expo_32;
|
||||||
|
x_expo_32 += uint32_t(15) << 23;
|
||||||
|
|
||||||
|
// Handle scaling to inf as needed
|
||||||
|
inf_scale.u = uint32_t(0x77800000);
|
||||||
|
zero_scale.u = uint32_t(0x08800000);
|
||||||
|
|
||||||
|
// Combine with magic and let addition do rouding
|
||||||
|
magic_bits.u = x_expo_32;
|
||||||
|
magic_bits.f += (std::abs(x) * inf_scale.f) * zero_scale.f;
|
||||||
|
|
||||||
|
// Take the lower 5 bits of the exponent
|
||||||
|
uint32_t x_expo_16 = ((magic_bits.u >> 13) & uint32_t(0x7c00));
|
||||||
|
|
||||||
|
// Collect the lower 12 bits which have the mantissa
|
||||||
|
uint32_t x_mant_16 = magic_bits.u & uint32_t(0x0fff);
|
||||||
|
|
||||||
|
// Combine sign, exp and mantissa
|
||||||
|
bits_ = (x_sign_16 | uint16_t(x_expo_16 + x_mant_16));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// To float32
|
||||||
|
operator float() const {
|
||||||
|
// Conversion following
|
||||||
|
// https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h
|
||||||
|
|
||||||
|
// Union
|
||||||
|
float_bits_fp16 out;
|
||||||
|
|
||||||
|
uint32_t x_sign_32 = (bits_ << 16) & uint32_t(0x80000000);
|
||||||
|
uint32_t base = (bits_ << 16);
|
||||||
|
uint32_t two_base = base + base;
|
||||||
|
|
||||||
|
uint32_t denorm_max = 1u << 27;
|
||||||
|
if (two_base < denorm_max) {
|
||||||
|
out.u = uint32_t(126) << 23; // magic mask
|
||||||
|
out.u |= (two_base >> 17); // Bits from fp16
|
||||||
|
out.f -= 0.5f; // magic bias
|
||||||
|
} else {
|
||||||
|
out.u = uint32_t(0xE0) << 23; // exponent offset
|
||||||
|
out.u += (two_base >> 4); // Bits from fp16
|
||||||
|
float out_unscaled = out.f; // Store value
|
||||||
|
out.u = uint32_t(0x7800000); // exponent scale
|
||||||
|
out.f *= out_unscaled;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add sign
|
||||||
|
out.u |= x_sign_32;
|
||||||
|
|
||||||
|
return out.f;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#define half_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
|
||||||
|
inline otype __operator__(atype lhs, btype rhs) { \
|
||||||
|
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define half_binop_helper(__op__, __operator__, otype, itype, ctype) \
|
||||||
|
inline otype __operator__(_MLX_Float16 lhs, itype rhs) { \
|
||||||
|
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
||||||
|
} \
|
||||||
|
inline otype __operator__(itype lhs, _MLX_Float16 rhs) { \
|
||||||
|
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
||||||
|
}
|
||||||
|
|
||||||
|
// Operators
|
||||||
|
#define half_binop(__op__, __operator__) \
|
||||||
|
half_binop_base( \
|
||||||
|
__op__, __operator__, _MLX_Float16, _MLX_Float16, _MLX_Float16, float); \
|
||||||
|
half_binop_helper(__op__, __operator__, float, float, float); \
|
||||||
|
half_binop_helper(__op__, __operator__, double, double, double); \
|
||||||
|
half_binop_helper(__op__, __operator__, _MLX_Float16, bool, float); \
|
||||||
|
half_binop_helper(__op__, __operator__, _MLX_Float16, int32_t, float); \
|
||||||
|
half_binop_helper(__op__, __operator__, _MLX_Float16, uint32_t, float); \
|
||||||
|
half_binop_helper(__op__, __operator__, _MLX_Float16, int64_t, float); \
|
||||||
|
half_binop_helper(__op__, __operator__, _MLX_Float16, uint64_t, float);
|
||||||
|
|
||||||
|
half_binop(+, operator+);
|
||||||
|
half_binop(-, operator-);
|
||||||
|
half_binop(*, operator*);
|
||||||
|
half_binop(/, operator/);
|
||||||
|
|
||||||
|
#undef half_binop
|
||||||
|
|
||||||
|
// Comparison ops
|
||||||
|
#define half_compop(__op__, __operator__) \
|
||||||
|
half_binop_base( \
|
||||||
|
__op__, __operator__, bool, _MLX_Float16, _MLX_Float16, float); \
|
||||||
|
half_binop_helper(__op__, __operator__, bool, float, float); \
|
||||||
|
half_binop_helper(__op__, __operator__, bool, double, double); \
|
||||||
|
half_binop_helper(__op__, __operator__, bool, int32_t, float); \
|
||||||
|
half_binop_helper(__op__, __operator__, bool, uint32_t, float); \
|
||||||
|
half_binop_helper(__op__, __operator__, bool, int64_t, float); \
|
||||||
|
half_binop_helper(__op__, __operator__, bool, uint64_t, float);
|
||||||
|
|
||||||
|
half_compop(>, operator>);
|
||||||
|
half_compop(<, operator<);
|
||||||
|
half_compop(>=, operator>=);
|
||||||
|
half_compop(<=, operator<=);
|
||||||
|
half_compop(==, operator==);
|
||||||
|
half_compop(!=, operator!=);
|
||||||
|
|
||||||
|
#undef half_compop
|
||||||
|
|
||||||
|
// Negative
|
||||||
|
inline _MLX_Float16 operator-(_MLX_Float16 lhs) {
|
||||||
|
return -static_cast<float>(lhs);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inplace ops
|
||||||
|
#define half_inplace_op(__op__, __operator__) \
|
||||||
|
inline _MLX_Float16& __operator__(_MLX_Float16& lhs, const float& rhs) { \
|
||||||
|
lhs = lhs __op__ rhs; \
|
||||||
|
return lhs; \
|
||||||
|
} \
|
||||||
|
inline float& __operator__(float& lhs, _MLX_Float16 rhs) { \
|
||||||
|
lhs = lhs __op__ rhs; \
|
||||||
|
return lhs; \
|
||||||
|
}
|
||||||
|
|
||||||
|
half_inplace_op(+, operator+=);
|
||||||
|
half_inplace_op(-, operator-=);
|
||||||
|
half_inplace_op(*, operator*=);
|
||||||
|
half_inplace_op(/, operator/=);
|
||||||
|
|
||||||
|
#undef half_inplace_op
|
||||||
|
|
||||||
|
// Bitwise ops
|
||||||
|
|
||||||
|
#define half_bitop(__op__, __operator__) \
|
||||||
|
inline _MLX_Float16 __operator__(_MLX_Float16 lhs, _MLX_Float16 rhs) { \
|
||||||
|
_MLX_Float16 out; \
|
||||||
|
out.bits_ = lhs.bits_ __op__ rhs.bits_; \
|
||||||
|
return out; \
|
||||||
|
} \
|
||||||
|
inline _MLX_Float16 __operator__(_MLX_Float16 lhs, uint16_t rhs) { \
|
||||||
|
_MLX_Float16 out; \
|
||||||
|
out.bits_ = lhs.bits_ __op__ rhs; \
|
||||||
|
return out; \
|
||||||
|
} \
|
||||||
|
inline _MLX_Float16 __operator__(uint16_t lhs, _MLX_Float16 rhs) { \
|
||||||
|
_MLX_Float16 out; \
|
||||||
|
out.bits_ = lhs __op__ rhs.bits_; \
|
||||||
|
return out; \
|
||||||
|
}
|
||||||
|
|
||||||
|
half_bitop(|, operator|);
|
||||||
|
half_bitop(&, operator&);
|
||||||
|
half_bitop(^, operator^);
|
||||||
|
|
||||||
|
#undef half_bitop
|
||||||
|
|
||||||
|
#define half_inplace_bitop(__op__, __operator__) \
|
||||||
|
inline _MLX_Float16& __operator__(_MLX_Float16& lhs, _MLX_Float16 rhs) { \
|
||||||
|
lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \
|
||||||
|
return lhs; \
|
||||||
|
} \
|
||||||
|
inline _MLX_Float16& __operator__(_MLX_Float16& lhs, uint16_t rhs) { \
|
||||||
|
lhs.bits_ = lhs.bits_ __op__ rhs; \
|
||||||
|
return lhs; \
|
||||||
|
}
|
||||||
|
|
||||||
|
half_inplace_bitop(|, operator|=);
|
||||||
|
half_inplace_bitop(&, operator&=);
|
||||||
|
half_inplace_bitop(^, operator^=);
|
||||||
|
|
||||||
|
#undef half_inplace_bitop
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
3
pyproject.toml
Normal file
3
pyproject.toml
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=42", "pybind11>=2.10", "cmake>=3.24"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
124
python/mlx/nn/layers/convolution.py
Normal file
124
python/mlx/nn/layers/convolution.py
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
import math
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx.nn.layers.base import Module
|
||||||
|
|
||||||
|
|
||||||
|
class Conv1d(Module):
|
||||||
|
"""Applies a 1-dimensional convolution over the multi-channel input sequence.
|
||||||
|
|
||||||
|
The channels are expected to be last i.e. the input shape should be ``NLC`` where:
|
||||||
|
- ``N`` is the batch dimension
|
||||||
|
- ``L`` is the sequence length
|
||||||
|
- ``C`` is the number of input channels
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): The number of input channels
|
||||||
|
out_channels (int): The number of output channels
|
||||||
|
kernel_size (int): The size of the convolution filters
|
||||||
|
stride (int, optional): The stride when applying the filter.
|
||||||
|
Default: 1.
|
||||||
|
padding (int, optional): How many positions to 0-pad the input with.
|
||||||
|
Default: 0.
|
||||||
|
bias (bool, optional): If ``True`` add a learnable bias to the output.
|
||||||
|
Default: ``True``
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: int,
|
||||||
|
stride: int = 1,
|
||||||
|
padding: int = 0,
|
||||||
|
bias: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
scale = math.sqrt(1 / (in_channels * kernel_size))
|
||||||
|
self.weight = mx.random.uniform(
|
||||||
|
low=-scale,
|
||||||
|
high=scale,
|
||||||
|
shape=(out_channels, kernel_size, in_channels),
|
||||||
|
)
|
||||||
|
if bias:
|
||||||
|
self.bias = mx.zeros((out_channels,))
|
||||||
|
|
||||||
|
self.padding = padding
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
def _extra_repr(self):
|
||||||
|
return (
|
||||||
|
f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
|
||||||
|
f"kernel_size={self.weight.shape[1]}, stride={self.stride}, "
|
||||||
|
f"padding={self.padding}, bias={'bias' in self}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
y = mx.conv1d(x, self.weight, self.stride, self.padding)
|
||||||
|
if "bias" in self:
|
||||||
|
y = y + self.bias
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2d(Module):
|
||||||
|
"""Applies a 2-dimensional convolution over the multi-channel input image.
|
||||||
|
|
||||||
|
The channels are expected to be last i.e. the input shape should be ``NHWC`` where:
|
||||||
|
- ``N`` is the batch dimension
|
||||||
|
- ``H`` is the input image height
|
||||||
|
- ``W`` is the input image width
|
||||||
|
- ``C`` is the number of input channels
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): The number of input channels.
|
||||||
|
out_channels (int): The number of output channels.
|
||||||
|
kernel_size (int or tuple): The size of the convolution filters.
|
||||||
|
stride (int or tuple, optional): The size of the stride when
|
||||||
|
applying the filter. Default: 0.
|
||||||
|
padding (int or tuple, optional): How many positions to 0-pad
|
||||||
|
the input with. Default: 0.
|
||||||
|
bias (bool, optional): If ``True`` add a learnable bias to the
|
||||||
|
output. Default: ``True``
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: Union[int, tuple],
|
||||||
|
stride: Union[int, tuple] = 1,
|
||||||
|
padding: Union[int, tuple] = 0,
|
||||||
|
bias: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
kernel_size, stride, padding = map(
|
||||||
|
lambda x: (x, x) if isinstance(x, int) else x,
|
||||||
|
(kernel_size, stride, padding),
|
||||||
|
)
|
||||||
|
scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1]))
|
||||||
|
self.weight = mx.random.uniform(
|
||||||
|
low=-scale,
|
||||||
|
high=scale,
|
||||||
|
shape=(out_channels, *kernel_size, in_channels),
|
||||||
|
)
|
||||||
|
if bias:
|
||||||
|
self.bias = mx.zeros((out_channels,))
|
||||||
|
|
||||||
|
self.padding = padding
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
def _extra_repr(self):
|
||||||
|
return (
|
||||||
|
f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
|
||||||
|
f"kernel_size={self.weight.shape[1:2]}, stride={self.stride}, "
|
||||||
|
f"padding={self.padding}, bias={'bias' in self}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
y = mx.conv2d(x, self.weight, self.stride, self.padding)
|
||||||
|
if "bias" in self:
|
||||||
|
y = y + self.bias
|
||||||
|
return y
|
28
python/mlx/nn/layers/embedding.py
Normal file
28
python/mlx/nn/layers/embedding.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx.nn.layers.base import Module
|
||||||
|
|
||||||
|
|
||||||
|
class Embedding(Module):
|
||||||
|
"""Implements a simple lookup table that maps each input integer to a
|
||||||
|
high-dimensional vector.
|
||||||
|
|
||||||
|
Typically used to embed discrete tokens for processing by neural networks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_embeddings (int): How many possible discrete tokens can we embed.
|
||||||
|
Usually called the vocabulary size.
|
||||||
|
dims (int): The dimensionality of the embeddings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, num_embeddings: int, dims: int):
|
||||||
|
super().__init__()
|
||||||
|
scale = math.sqrt(1 / dims)
|
||||||
|
self.weight = mx.random.normal((num_embeddings, dims)) * scale
|
||||||
|
|
||||||
|
def _extra_repr(self):
|
||||||
|
return f"{self.weight.shape[0]}, {self.weight.shape[1]}"
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return self.weight[x]
|
34
python/mlx/nn/layers/linear.py
Normal file
34
python/mlx/nn/layers/linear.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx.nn.layers.base import Module
|
||||||
|
|
||||||
|
|
||||||
|
class Linear(Module):
|
||||||
|
"""Applies an affine transformation to the input.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dims (int): The dimensionality of the input features
|
||||||
|
output_dims (int): The dimensionality of the output features
|
||||||
|
bias (bool): If set to False then the layer will not use a bias
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, input_dims: int, output_dims: int, bias: bool = True):
|
||||||
|
super().__init__()
|
||||||
|
scale = math.sqrt(1 / input_dims)
|
||||||
|
self.weight = mx.random.uniform(
|
||||||
|
low=-scale,
|
||||||
|
high=scale,
|
||||||
|
shape=(output_dims, input_dims),
|
||||||
|
)
|
||||||
|
if bias:
|
||||||
|
self.bias = mx.zeros((output_dims,))
|
||||||
|
|
||||||
|
def _extra_repr(self):
|
||||||
|
return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
x = x @ self.weight.T
|
||||||
|
if "bias" in self:
|
||||||
|
x = x + self.bias
|
||||||
|
return x
|
19
python/src/load.h
Normal file
19
python/src/load.h
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <pybind11/pybind11.h>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <variant>
|
||||||
|
#include "mlx/ops.h"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>;
|
||||||
|
|
||||||
|
DictOrArray mlx_load_helper(py::object file, StreamOrDevice s);
|
||||||
|
void mlx_save_helper(py::object file, array a, bool retain_graph = true);
|
||||||
|
void mlx_savez_helper(
|
||||||
|
py::object file,
|
||||||
|
py::args args,
|
||||||
|
const py::kwargs& kwargs,
|
||||||
|
bool compressed = false);
|
31
python/src/mlx.cpp
Normal file
31
python/src/mlx.cpp
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
#include <pybind11/pybind11.h>
|
||||||
|
|
||||||
|
#define STRINGIFY(x) #x
|
||||||
|
#define TOSTRING(x) STRINGIFY(x)
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
void init_array(py::module_&);
|
||||||
|
void init_device(py::module_&);
|
||||||
|
void init_stream(py::module_&);
|
||||||
|
void init_metal(py::module_&);
|
||||||
|
void init_ops(py::module_&);
|
||||||
|
void init_transforms(py::module_&);
|
||||||
|
void init_random(py::module_&);
|
||||||
|
void init_fft(py::module_&);
|
||||||
|
|
||||||
|
PYBIND11_MODULE(core, m) {
|
||||||
|
m.doc() = "mlx: A framework for machine learning on Apple Silicon.";
|
||||||
|
|
||||||
|
auto reprlib_fix = py::module_::import("mlx._reprlib_fix");
|
||||||
|
|
||||||
|
init_device(m);
|
||||||
|
init_stream(m);
|
||||||
|
init_array(m);
|
||||||
|
init_metal(m);
|
||||||
|
init_ops(m);
|
||||||
|
init_transforms(m);
|
||||||
|
init_random(m);
|
||||||
|
init_fft(m);
|
||||||
|
m.attr("__version__") = TOSTRING(_VERSION_);
|
||||||
|
}
|
723
python/src/transforms.cpp
Normal file
723
python/src/transforms.cpp
Normal file
@ -0,0 +1,723 @@
|
|||||||
|
#include <pybind11/functional.h>
|
||||||
|
#include <pybind11/pybind11.h>
|
||||||
|
#include <pybind11/stl.h>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <fstream>
|
||||||
|
#include <numeric>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/graph_utils.h"
|
||||||
|
#include "mlx/transforms.h"
|
||||||
|
#include "mlx/transforms_impl.h"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
using namespace py::literals;
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
using IntOrVec = std::variant<int, std::vector<int>>;
|
||||||
|
using StrOrVec = std::variant<std::string, std::vector<std::string>>;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::vector<T> to_vector(const std::variant<T, std::vector<T>>& v) {
|
||||||
|
std::vector<T> vals;
|
||||||
|
if (auto pv = std::get_if<T>(&v); pv) {
|
||||||
|
vals.push_back(*pv);
|
||||||
|
} else {
|
||||||
|
vals = std::get<std::vector<T>>(v);
|
||||||
|
}
|
||||||
|
return vals;
|
||||||
|
}
|
||||||
|
|
||||||
|
void tree_visit(py::object tree, std::function<void(py::handle)> visitor) {
|
||||||
|
std::function<void(py::handle)> recurse;
|
||||||
|
recurse = [&](py::handle subtree) {
|
||||||
|
if (py::isinstance<py::list>(subtree) ||
|
||||||
|
py::isinstance<py::tuple>(subtree)) {
|
||||||
|
for (auto item : subtree) {
|
||||||
|
recurse(item);
|
||||||
|
}
|
||||||
|
} else if (py::isinstance<py::dict>(subtree)) {
|
||||||
|
for (auto item : py::cast<py::dict>(subtree)) {
|
||||||
|
recurse(item.second);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
visitor(subtree);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
recurse(tree);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename V>
|
||||||
|
void validate_subtrees(const std::vector<py::object>& subtrees) {
|
||||||
|
int len = py::cast<T>(subtrees[0]).size();
|
||||||
|
for (auto& subtree : subtrees) {
|
||||||
|
if ((py::isinstance<T>(subtree) && py::cast<T>(subtree).size() != len) ||
|
||||||
|
py::isinstance<U>(subtree) || py::isinstance<V>(subtree)) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[tree_map] Additional input tree is not a valid prefix of the first tree.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
py::object tree_map(
|
||||||
|
const std::vector<py::object>& trees,
|
||||||
|
std::function<py::object(const std::vector<py::object>&)> transform) {
|
||||||
|
std::function<py::object(const std::vector<py::object>&)> recurse;
|
||||||
|
|
||||||
|
recurse = [&](const std::vector<py::object>& subtrees) {
|
||||||
|
if (py::isinstance<py::list>(subtrees[0])) {
|
||||||
|
py::list l;
|
||||||
|
std::vector<py::object> items(subtrees.size());
|
||||||
|
validate_subtrees<py::list, py::tuple, py::dict>(subtrees);
|
||||||
|
for (int i = 0; i < py::cast<py::list>(subtrees[0]).size(); ++i) {
|
||||||
|
for (int j = 0; j < subtrees.size(); ++j) {
|
||||||
|
if (py::isinstance<py::list>(subtrees[j])) {
|
||||||
|
items[j] = py::cast<py::list>(subtrees[j])[i];
|
||||||
|
} else {
|
||||||
|
items[j] = subtrees[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
l.append(recurse(items));
|
||||||
|
}
|
||||||
|
return py::cast<py::object>(l);
|
||||||
|
} else if (py::isinstance<py::tuple>(subtrees[0])) {
|
||||||
|
// Check the rest of the subtrees
|
||||||
|
std::vector<py::object> items(subtrees.size());
|
||||||
|
int len = py::cast<py::tuple>(subtrees[0]).size();
|
||||||
|
py::tuple l(len);
|
||||||
|
validate_subtrees<py::tuple, py::list, py::dict>(subtrees);
|
||||||
|
for (int i = 0; i < len; ++i) {
|
||||||
|
for (int j = 0; j < subtrees.size(); ++j) {
|
||||||
|
if (py::isinstance<py::tuple>(subtrees[j])) {
|
||||||
|
items[j] = py::cast<py::tuple>(subtrees[j])[i];
|
||||||
|
} else {
|
||||||
|
items[j] = subtrees[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
l[i] = recurse(items);
|
||||||
|
}
|
||||||
|
return py::cast<py::object>(l);
|
||||||
|
} else if (py::isinstance<py::dict>(subtrees[0])) {
|
||||||
|
std::vector<py::object> items(subtrees.size());
|
||||||
|
validate_subtrees<py::dict, py::list, py::tuple>(subtrees);
|
||||||
|
py::dict d;
|
||||||
|
for (auto item : py::cast<py::dict>(subtrees[0])) {
|
||||||
|
for (int j = 0; j < subtrees.size(); ++j) {
|
||||||
|
if (py::isinstance<py::dict>(subtrees[j])) {
|
||||||
|
auto subdict = py::cast<py::dict>(subtrees[j]);
|
||||||
|
if (!subdict.contains(item.first)) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[tree_map] Tree is not a valid prefix tree of the first tree.");
|
||||||
|
}
|
||||||
|
items[j] = subdict[item.first];
|
||||||
|
} else {
|
||||||
|
items[j] = subtrees[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
d[item.first] = recurse(items);
|
||||||
|
}
|
||||||
|
return py::cast<py::object>(d);
|
||||||
|
} else {
|
||||||
|
return transform(subtrees);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
return recurse(trees);
|
||||||
|
}
|
||||||
|
|
||||||
|
py::object tree_map(
|
||||||
|
py::object tree,
|
||||||
|
std::function<py::object(py::handle)> transform) {
|
||||||
|
return tree_map({tree}, [&](std::vector<py::object> inputs) {
|
||||||
|
return transform(inputs[0]);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> tree_flatten(py::object tree, bool strict = true) {
|
||||||
|
std::vector<array> flat_tree;
|
||||||
|
|
||||||
|
tree_visit(tree, [&](py::handle obj) {
|
||||||
|
if (py::isinstance<array>(obj)) {
|
||||||
|
flat_tree.push_back(py::cast<array>(obj));
|
||||||
|
} else if (strict) {
|
||||||
|
throw std::invalid_argument("Argument is not an array");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
return flat_tree;
|
||||||
|
}
|
||||||
|
|
||||||
|
py::object tree_unflatten(
|
||||||
|
py::object tree,
|
||||||
|
const std::vector<array>& values,
|
||||||
|
int index = 0) {
|
||||||
|
return tree_map(tree, [&](py::handle obj) {
|
||||||
|
if (py::isinstance<array>(obj)) {
|
||||||
|
return py::cast(values[index++]);
|
||||||
|
} else {
|
||||||
|
return py::cast<py::object>(obj);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto validate_argnums_argnames(
|
||||||
|
const std::optional<IntOrVec>& argnums,
|
||||||
|
const StrOrVec& argnames) {
|
||||||
|
auto vec_names = to_vector(argnames);
|
||||||
|
|
||||||
|
if (!argnums.has_value()) {
|
||||||
|
// argnums was not provided and argnames was empty
|
||||||
|
if (vec_names.empty()) {
|
||||||
|
return std::make_pair(std::vector<int>{0}, vec_names);
|
||||||
|
} else {
|
||||||
|
return std::make_pair(std::vector<int>{}, vec_names);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_pair(to_vector(*argnums), vec_names);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto py_value_and_grad(
|
||||||
|
const py::function& fun,
|
||||||
|
std::vector<int> argnums,
|
||||||
|
std::vector<std::string> argnames,
|
||||||
|
const std::string& error_msg_tag,
|
||||||
|
bool scalar_func_only) {
|
||||||
|
// Sanitize argnums
|
||||||
|
if (argnums.size() == 0 && argnames.size() == 0) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
error_msg_tag + " Gradient wrt no argument requested");
|
||||||
|
}
|
||||||
|
if (argnums.size() > 0) {
|
||||||
|
std::sort(argnums.begin(), argnums.end());
|
||||||
|
if (argnums[0] < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << error_msg_tag
|
||||||
|
<< " Can't compute the gradient of negative argument index "
|
||||||
|
<< argnums[0];
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return [fun, argnums, argnames, error_msg_tag, scalar_func_only](
|
||||||
|
const py::args& args, const py::kwargs& kwargs) {
|
||||||
|
// Sanitize the input
|
||||||
|
if (argnums.size() > 0 && argnums.back() >= args.size()) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << error_msg_tag << " Can't compute the gradient of argument index "
|
||||||
|
<< argnums.back() << " because the function is called with only "
|
||||||
|
<< args.size() << " arguments.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto& key : argnames) {
|
||||||
|
if (!kwargs.contains(key)) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << error_msg_tag
|
||||||
|
<< " Can't compute the gradient of keyword argument '" << key
|
||||||
|
<< "' because the function is called with the "
|
||||||
|
<< "following keyword arguments {";
|
||||||
|
for (auto item : kwargs) {
|
||||||
|
msg << item.first.cast<std::string>() << ",";
|
||||||
|
}
|
||||||
|
msg << "}";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect the arrays
|
||||||
|
std::vector<array> arrays;
|
||||||
|
std::vector<int> counts(1, 0);
|
||||||
|
for (auto i : argnums) {
|
||||||
|
auto argsi = tree_flatten(args[i]);
|
||||||
|
arrays.insert(arrays.end(), argsi.begin(), argsi.end());
|
||||||
|
counts.push_back(argsi.size());
|
||||||
|
}
|
||||||
|
for (auto& key : argnames) {
|
||||||
|
auto argsk = tree_flatten(kwargs[key.c_str()]);
|
||||||
|
arrays.insert(arrays.end(), argsk.begin(), argsk.end());
|
||||||
|
counts.push_back(argsk.size());
|
||||||
|
}
|
||||||
|
std::partial_sum(counts.cbegin(), counts.cend(), counts.begin());
|
||||||
|
std::vector<int> gradient_indices(arrays.size());
|
||||||
|
std::iota(gradient_indices.begin(), gradient_indices.end(), 0);
|
||||||
|
|
||||||
|
// value_out will hold the output of the python function in order to be
|
||||||
|
// able to reconstruct the python tree of extra return values
|
||||||
|
py::object py_value_out;
|
||||||
|
auto value_and_grads = value_and_grad(
|
||||||
|
[&fun,
|
||||||
|
&args,
|
||||||
|
&kwargs,
|
||||||
|
&argnums,
|
||||||
|
&argnames,
|
||||||
|
&counts,
|
||||||
|
&py_value_out,
|
||||||
|
&error_msg_tag,
|
||||||
|
scalar_func_only](const std::vector<array>& a) {
|
||||||
|
// Copy the arguments
|
||||||
|
py::args args_cpy = py::tuple(args.size());
|
||||||
|
py::kwargs kwargs_cpy = py::kwargs();
|
||||||
|
int j = 0;
|
||||||
|
for (int i = 0; i < args.size(); ++i) {
|
||||||
|
if (j < argnums.size() && i == argnums[j]) {
|
||||||
|
args_cpy[i] = tree_unflatten(args[i], a, counts[j]);
|
||||||
|
j++;
|
||||||
|
} else {
|
||||||
|
args_cpy[i] = args[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (auto& key : argnames) {
|
||||||
|
kwargs_cpy[key.c_str()] =
|
||||||
|
tree_unflatten(kwargs[key.c_str()], a, counts[j]);
|
||||||
|
j++;
|
||||||
|
}
|
||||||
|
for (auto item : kwargs) {
|
||||||
|
if (kwargs_cpy.contains(item.first)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
kwargs_cpy[item.first] = item.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the python function
|
||||||
|
py_value_out = fun(*args_cpy, **kwargs_cpy);
|
||||||
|
|
||||||
|
// Validate the return value of the python function
|
||||||
|
if (!py::isinstance<array>(py_value_out)) {
|
||||||
|
if (scalar_func_only) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << error_msg_tag << " The return value of the function "
|
||||||
|
<< "whose gradient we want to compute should be a "
|
||||||
|
<< "scalar array; but " << py_value_out.get_type()
|
||||||
|
<< " was returned.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
if (!py::isinstance<py::tuple>(py_value_out)) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << error_msg_tag << " The return value of the function "
|
||||||
|
<< "whose gradient we want to compute should be either a "
|
||||||
|
<< "scalar array or a tuple with the first value being a "
|
||||||
|
<< "scalar array (Union[array, Tuple[array, Any, ...]]); but "
|
||||||
|
<< py_value_out.get_type() << " was returned.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
py::tuple ret = py::cast<py::tuple>(py_value_out);
|
||||||
|
if (ret.size() == 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << error_msg_tag << " The return value of the function "
|
||||||
|
<< "whose gradient we want to compute should be either a "
|
||||||
|
<< "scalar array or a non-empty tuple. The first value should be a "
|
||||||
|
<< "scalar array and the rest can be anything. Instead, "
|
||||||
|
<< "we got an empty tuple.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
if (!py::isinstance<array>(ret[0])) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << error_msg_tag << " The return value of the function "
|
||||||
|
<< "whose gradient we want to compute should be either a "
|
||||||
|
<< "scalar array or a tuple with the first value being a "
|
||||||
|
<< "scalar array (Union[array, Tuple[array, Any, ...]]); but it "
|
||||||
|
<< "was a tuple with the first value being of type "
|
||||||
|
<< ret[0].get_type() << " .";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tree_flatten(py_value_out, false);
|
||||||
|
},
|
||||||
|
gradient_indices)(arrays);
|
||||||
|
|
||||||
|
auto value = value_and_grads.first;
|
||||||
|
auto gradients = value_and_grads.second;
|
||||||
|
|
||||||
|
// Put the gradients back in their container.
|
||||||
|
// We have the following cases:
|
||||||
|
//
|
||||||
|
// 1. Single python positional argument has a gradient (eg argnums=[0])
|
||||||
|
// 2. Many python positional arguments have gradients (eg argnums=[0, 1])
|
||||||
|
// 3. A python keyword argument has gradients
|
||||||
|
//
|
||||||
|
// In case 1 we return the original python variable but with the gradients.
|
||||||
|
// In case 2 we return a tuple of the above.
|
||||||
|
// In case 3 we return a tuple containing a tuple and dict (sth like
|
||||||
|
// (tuple(), dict(x=mx.array(5))) ).
|
||||||
|
py::object positional_grads;
|
||||||
|
py::object keyword_grads;
|
||||||
|
py::object py_grads;
|
||||||
|
|
||||||
|
// Collect the gradients for the positional arguments
|
||||||
|
if (argnums.size() == 1) {
|
||||||
|
positional_grads = tree_unflatten(args[argnums[0]], gradients, counts[0]);
|
||||||
|
} else if (argnums.size() > 1) {
|
||||||
|
py::tuple grads_(argnums.size());
|
||||||
|
for (int i = 0; i < argnums.size(); i++) {
|
||||||
|
grads_[i] = tree_unflatten(args[argnums[i]], gradients, counts[i]);
|
||||||
|
}
|
||||||
|
positional_grads = py::cast<py::object>(grads_);
|
||||||
|
} else {
|
||||||
|
positional_grads = py::none();
|
||||||
|
}
|
||||||
|
|
||||||
|
// No keyword argument gradients so return the tuple of gradients
|
||||||
|
if (argnames.size() == 0) {
|
||||||
|
py_grads = positional_grads;
|
||||||
|
} else {
|
||||||
|
py::dict grads_;
|
||||||
|
for (int i = 0; i < argnames.size(); i++) {
|
||||||
|
auto& k = argnames[i];
|
||||||
|
grads_[k.c_str()] = tree_unflatten(
|
||||||
|
kwargs[k.c_str()], gradients, counts[i + argnums.size()]);
|
||||||
|
}
|
||||||
|
keyword_grads = py::cast<py::object>(grads_);
|
||||||
|
|
||||||
|
py_grads =
|
||||||
|
py::cast<py::object>(py::make_tuple(positional_grads, keyword_grads));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put the values back in the container
|
||||||
|
py::object return_value = tree_unflatten(py_value_out, value);
|
||||||
|
return std::make_pair(return_value, py_grads);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
auto py_vmap(
|
||||||
|
const py::function& fun,
|
||||||
|
const py::object& in_axes,
|
||||||
|
const py::object& out_axes) {
|
||||||
|
return [fun, in_axes, out_axes](const py::args& args) {
|
||||||
|
auto axes_to_flat_tree = [](const py::object& tree,
|
||||||
|
const py::object& axes) {
|
||||||
|
auto tree_axes = tree_map(
|
||||||
|
{tree, axes},
|
||||||
|
[](const std::vector<py::object>& inputs) { return inputs[1]; });
|
||||||
|
std::vector<int> flat_axes;
|
||||||
|
tree_visit(tree_axes, [&flat_axes](py::handle obj) {
|
||||||
|
if (obj.is_none()) {
|
||||||
|
flat_axes.push_back(-1);
|
||||||
|
} else if (py::isinstance<py::int_>(obj)) {
|
||||||
|
flat_axes.push_back(py::cast<int>(py::cast<py::int_>(obj)));
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument("[vmap] axis must be int or None.");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return flat_axes;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Inputs must be array or tree of arrays
|
||||||
|
auto inputs = tree_flatten(args, true);
|
||||||
|
auto flat_in_axes = axes_to_flat_tree(args, in_axes);
|
||||||
|
|
||||||
|
// py_value_out will hold the output of the python function in order to be
|
||||||
|
// able to reconstruct the python tree of extra return values
|
||||||
|
py::object py_outputs;
|
||||||
|
|
||||||
|
auto vmap_fn =
|
||||||
|
[&fun, &args, &inputs, &py_outputs](const std::vector<array>& a) {
|
||||||
|
// Call the python function
|
||||||
|
py_outputs = fun(*tree_unflatten(args, a));
|
||||||
|
|
||||||
|
// Flatten the outputs
|
||||||
|
return tree_flatten(py_outputs, true);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto [trace_inputs, trace_outputs] =
|
||||||
|
detail::vmap_trace(vmap_fn, inputs, flat_in_axes);
|
||||||
|
|
||||||
|
auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes);
|
||||||
|
|
||||||
|
// Perform the vmap
|
||||||
|
auto outputs = detail::vmap_replace(
|
||||||
|
inputs, trace_inputs, trace_outputs, flat_in_axes, flat_out_axes);
|
||||||
|
|
||||||
|
// Put the outputs back in the container
|
||||||
|
return tree_unflatten(py_outputs, outputs);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
void init_transforms(py::module_& m) {
|
||||||
|
m.def(
|
||||||
|
"eval",
|
||||||
|
[](const py::args& args, bool retain_graph) {
|
||||||
|
std::vector<array> arrays = tree_flatten(args);
|
||||||
|
eval(arrays, retain_graph);
|
||||||
|
},
|
||||||
|
"retain_graph"_a = false,
|
||||||
|
R"pbdoc(
|
||||||
|
Evaluate an :class:`array` or tree of :class:`array`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args (arrays or trees of arrays): Each argument can be a single array
|
||||||
|
or a tree of arrays. If a tree is given the nodes can be a Python
|
||||||
|
:class:`list`, :class:`tuple` or :class:`dict` but the leafs must all be
|
||||||
|
an :class:`array`.
|
||||||
|
retain_graph (bool): Indicate that the graph structure should be
|
||||||
|
preserved. This option is intended to enable function transforms
|
||||||
|
which contain control flow based on the value of an array.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"jvp",
|
||||||
|
[](const py::function& fun,
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& tangents) {
|
||||||
|
auto vfun = [&fun](const std::vector<array>& primals) {
|
||||||
|
py::args args = py::tuple(primals.size());
|
||||||
|
for (int i = 0; i < primals.size(); ++i) {
|
||||||
|
args[i] = primals[i];
|
||||||
|
}
|
||||||
|
auto out = fun(*args);
|
||||||
|
if (py::isinstance<array>(out)) {
|
||||||
|
return std::vector<array>{py::cast<array>(out)};
|
||||||
|
} else {
|
||||||
|
return py::cast<std::vector<array>>(out);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
return jvp(vfun, primals, tangents);
|
||||||
|
},
|
||||||
|
"fun"_a,
|
||||||
|
"primals"_a,
|
||||||
|
"tangents"_a,
|
||||||
|
R"pbdoc(
|
||||||
|
Compute the Jacobian-vector product.
|
||||||
|
|
||||||
|
This computes the product of the Jacobian of a function ``fun`` evaluated
|
||||||
|
at ``primals`` with the ``tangents``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fun (function): A function which takes a variable number of :class:`array`
|
||||||
|
and returns a single :class:`array` or list of :class:`array`.
|
||||||
|
primals (list(array)): A list of :class:`array` at which to
|
||||||
|
evaluate the Jacobian.
|
||||||
|
tangents (list(array)): A list of :class:`array` which are the
|
||||||
|
"vector" in the Jacobian-vector product. The ``tangents`` should be the
|
||||||
|
same in number, shape, and type as the inputs of ``fun`` (i.e. the ``primals``).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list(array): A list of the Jacobian-vector products which
|
||||||
|
is the same in number, shape, and type of the inputs to ``fun``.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"vjp",
|
||||||
|
[](const py::function& fun,
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& cotangents) {
|
||||||
|
auto vfun = [&fun](const std::vector<array>& primals) {
|
||||||
|
py::args args = py::tuple(primals.size());
|
||||||
|
for (int i = 0; i < primals.size(); ++i) {
|
||||||
|
args[i] = primals[i];
|
||||||
|
}
|
||||||
|
auto out = fun(*args);
|
||||||
|
if (py::isinstance<array>(out)) {
|
||||||
|
return std::vector<array>{py::cast<array>(out)};
|
||||||
|
} else {
|
||||||
|
return py::cast<std::vector<array>>(out);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
return vjp(vfun, primals, cotangents);
|
||||||
|
},
|
||||||
|
"fun"_a,
|
||||||
|
"primals"_a,
|
||||||
|
"cotangents"_a,
|
||||||
|
R"pbdoc(
|
||||||
|
Compute the vector-Jacobian product.
|
||||||
|
|
||||||
|
Computes the product of the ``cotangents`` with the Jacobian of a
|
||||||
|
function ``fun`` evaluated at ``primals``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fun (function): A function which takes a variable number of :class:`array`
|
||||||
|
and returns a single :class:`array` or list of :class:`array`.
|
||||||
|
primals (list(array)): A list of :class:`array` at which to
|
||||||
|
evaluate the Jacobian.
|
||||||
|
cotangents (list(array)): A list of :class:`array` which are the
|
||||||
|
"vector" in the vector-Jacobian product. The ``cotangents`` should be the
|
||||||
|
same in number, shape, and type as the outputs of ``fun``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list(array): A list of the vector-Jacobian products which
|
||||||
|
is the same in number, shape, and type of the outputs of ``fun``.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"value_and_grad",
|
||||||
|
[](const py::function& fun,
|
||||||
|
const std::optional<IntOrVec>& argnums,
|
||||||
|
const StrOrVec& argnames) {
|
||||||
|
auto [argnums_vec, argnames_vec] =
|
||||||
|
validate_argnums_argnames(argnums, argnames);
|
||||||
|
return py::cpp_function(py_value_and_grad(
|
||||||
|
fun, argnums_vec, argnames_vec, "[value_and_grad]", false));
|
||||||
|
},
|
||||||
|
"fun"_a,
|
||||||
|
"argnums"_a = std::nullopt,
|
||||||
|
"argnames"_a = std::vector<std::string>{},
|
||||||
|
R"pbdoc(
|
||||||
|
Returns a function which computes the value and gradient of ``fun``.
|
||||||
|
|
||||||
|
The function passed to :func:`value_and_grad` should return either
|
||||||
|
a scalar loss or a tuple in which the first element is a scalar
|
||||||
|
loss and the remaining elements can be anything.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
def mse(params, inputs, targets):
|
||||||
|
outputs = forward(params, inputs)
|
||||||
|
lvalue = (outputs - targets).square().mean()
|
||||||
|
return lvalue
|
||||||
|
|
||||||
|
# Returns lvalue, dlvalue/dparams
|
||||||
|
lvalue, grads = mx.value_and_grad(mse)
|
||||||
|
|
||||||
|
def lasso(params, inputs, targets, a=1.0, b=1.0):
|
||||||
|
outputs = forward(params, inputs)
|
||||||
|
mse = (outputs - targets).square().mean()
|
||||||
|
l1 = mx.abs(outputs - targets).mean()
|
||||||
|
|
||||||
|
loss = a*mse + b*l1
|
||||||
|
|
||||||
|
return loss, mse, l1
|
||||||
|
|
||||||
|
(loss, mse, l1), grads = mx.value_and_grad(lasso)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fun (function): A function which takes a variable number of
|
||||||
|
:class:`array` or trees of :class:`array` and returns
|
||||||
|
a scalar output :class:`array` or a tuple the first element
|
||||||
|
of which should be a scalar :class:`array`.
|
||||||
|
argnums (int or list(int), optional): Specify the index (or indices)
|
||||||
|
of the positional arguments of ``fun`` to compute the gradient
|
||||||
|
with respect to. If neither ``argnums`` nor ``argnames`` are
|
||||||
|
provided ``argnums`` defaults to ``0`` indicating ``fun``'s first
|
||||||
|
argument.
|
||||||
|
argnames (str or list(str), optional): Specify keyword arguments of
|
||||||
|
``fun`` to compute gradients with respect to. It defaults to [] so
|
||||||
|
no gradients for keyword arguments by default.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
function: A function which returns a tuple where the first element
|
||||||
|
is the output of `fun` and the second element is the gradients w.r.t.
|
||||||
|
the loss.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"grad",
|
||||||
|
[](const py::function& fun,
|
||||||
|
const std::optional<IntOrVec>& argnums,
|
||||||
|
const StrOrVec& argnames) {
|
||||||
|
auto [argnums_vec, argnames_vec] =
|
||||||
|
validate_argnums_argnames(argnums, argnames);
|
||||||
|
auto fn =
|
||||||
|
py_value_and_grad(fun, argnums_vec, argnames_vec, "[grad]", true);
|
||||||
|
return py::cpp_function(
|
||||||
|
[fn](const py::args& args, const py::kwargs& kwargs) {
|
||||||
|
return fn(args, kwargs).second;
|
||||||
|
});
|
||||||
|
},
|
||||||
|
"fun"_a,
|
||||||
|
"argnums"_a = std::nullopt,
|
||||||
|
"argnames"_a = std::vector<std::string>{},
|
||||||
|
R"pbdoc(
|
||||||
|
Returns a function which computes the gradient of ``fun``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fun (function): A function which takes a variable number of
|
||||||
|
:class:`array` or trees of :class:`array` and returns
|
||||||
|
a scalar output :class:`array`.
|
||||||
|
argnums (int or list(int), optional): Specify the index (or indices)
|
||||||
|
of the positional arguments of ``fun`` to compute the gradient
|
||||||
|
with respect to. If neither ``argnums`` nor ``argnames`` are
|
||||||
|
provided ``argnums`` defaults to ``0`` indicating ``fun``'s first
|
||||||
|
argument.
|
||||||
|
argnames (str or list(str), optional): Specify keyword arguments of
|
||||||
|
``fun`` to compute gradients with respect to. It defaults to [] so
|
||||||
|
no gradients for keyword arguments by default.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
function: A function which has the same input arguments as ``fun`` and
|
||||||
|
returns the gradient(s).
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"vmap",
|
||||||
|
[](const py::function& fun,
|
||||||
|
const py::object& in_axes,
|
||||||
|
const py::object& out_axes) {
|
||||||
|
return py::cpp_function(py_vmap(fun, in_axes, out_axes));
|
||||||
|
},
|
||||||
|
"fun"_a,
|
||||||
|
"in_axes"_a = 0,
|
||||||
|
"out_axes"_a = 0,
|
||||||
|
R"pbdoc(
|
||||||
|
Returns a vectorized version of ``fun``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fun (function): A function which takes a variable number of
|
||||||
|
:class:`array` or a tree of :class:`array` and returns
|
||||||
|
a variable number of :class:`array` or a tree of :class:`array`.
|
||||||
|
in_axes (int, optional): An integer or a valid prefix tree of the
|
||||||
|
inputs to ``fun`` where each node specifies the vmapped axis. If
|
||||||
|
the value is ``None`` then the corresponding input(s) are not vmapped.
|
||||||
|
Defaults to ``0``.
|
||||||
|
out_axes (int, optional): An integer or a valid prefix tree of the
|
||||||
|
outputs of ``fun`` where each node specifies the vmapped axis. If
|
||||||
|
the value is ``None`` then the corresponding outputs(s) are not vmapped.
|
||||||
|
Defaults to ``0``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
function: The vectorized function.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"simplify",
|
||||||
|
[](const py::args& args) {
|
||||||
|
std::vector<array> arrays = tree_flatten(args);
|
||||||
|
simplify(arrays);
|
||||||
|
},
|
||||||
|
R"pbdoc(
|
||||||
|
Simplify the graph that computes the arrays.
|
||||||
|
|
||||||
|
Run a few fast graph simplification operations to reuse computation and
|
||||||
|
reduce memory consumption. This function is meant to be run every time
|
||||||
|
so its overhead should be small, approximately 1ms for a graph with a
|
||||||
|
few thousand nodes.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
def foo(x):
|
||||||
|
y = x @ x
|
||||||
|
z = x @ x
|
||||||
|
return y + z
|
||||||
|
|
||||||
|
x = mx.ones((10, 10))
|
||||||
|
y = foo(x)
|
||||||
|
z = foo(x)
|
||||||
|
|
||||||
|
# Computes the matmul twice
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
# Computes the matmul once
|
||||||
|
mx.simplify(z)
|
||||||
|
mx.eval(z)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: Any number of arrays and/or trees of arrays to be simplified.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"export_to_dot",
|
||||||
|
[](py::object file, const py::args& args) {
|
||||||
|
std::vector<array> arrays = tree_flatten(args);
|
||||||
|
if (py::isinstance<py::str>(file)) {
|
||||||
|
std::ofstream out(py::cast<std::string>(file));
|
||||||
|
export_to_dot(out, arrays);
|
||||||
|
} else if (py::hasattr(file, "write")) {
|
||||||
|
std::ostringstream out;
|
||||||
|
export_to_dot(out, arrays);
|
||||||
|
auto write = file.attr("write");
|
||||||
|
write(out.str());
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"export_to_dot accepts file-like objects or strings to be used as filenames");
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"file"_a);
|
||||||
|
}
|
16
python/tests/mlx_tests.py
Normal file
16
python/tests/mlx_tests.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
|
||||||
|
class MLXTestCase(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.default = mx.default_device()
|
||||||
|
device = os.getenv("DEVICE", None)
|
||||||
|
if device is not None:
|
||||||
|
device = getattr(mx, device)
|
||||||
|
mx.set_default_device(device)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
mx.set_default_device(self.default)
|
445
python/tests/test_blas.py
Normal file
445
python/tests/test_blas.py
Normal file
@ -0,0 +1,445 @@
|
|||||||
|
import unittest
|
||||||
|
from itertools import permutations
|
||||||
|
|
||||||
|
import math
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mlx_tests
|
||||||
|
|
||||||
|
|
||||||
|
class TestBlas(mlx_tests.MLXTestCase):
|
||||||
|
@property
|
||||||
|
def dtypes(self):
|
||||||
|
return ["float32", "float16"] if mx.metal.is_available() else ["float32"]
|
||||||
|
|
||||||
|
def __gemm_test(
|
||||||
|
self,
|
||||||
|
shape_a,
|
||||||
|
shape_b,
|
||||||
|
np_dtype=np.float32,
|
||||||
|
f_np_a=lambda x: x,
|
||||||
|
f_np_b=lambda x: x,
|
||||||
|
f_mx_a=lambda x: x,
|
||||||
|
f_mx_b=lambda x: x,
|
||||||
|
):
|
||||||
|
with self.subTest(
|
||||||
|
dtype=np.dtype(np_dtype).name, shape_a=shape_a, shape_b=shape_b
|
||||||
|
):
|
||||||
|
np.random.seed(42)
|
||||||
|
scale = max(np.sum(shape_a), 128)
|
||||||
|
a_np = np.random.normal(0.0, 1.0 / scale, shape_a).astype(np_dtype)
|
||||||
|
b_np = np.random.normal(0.0, 1.0 / scale, shape_b).astype(np_dtype)
|
||||||
|
|
||||||
|
a_mx = mx.array(a_np)
|
||||||
|
b_mx = mx.array(b_np)
|
||||||
|
|
||||||
|
a_np = f_np_a(a_np.astype(np.float32))
|
||||||
|
b_np = f_np_b(b_np.astype(np.float32))
|
||||||
|
a_mx = f_mx_a(a_mx)
|
||||||
|
b_mx = f_mx_b(b_mx)
|
||||||
|
|
||||||
|
out_npy = a_np @ b_np
|
||||||
|
out_mlx = a_mx @ b_mx
|
||||||
|
|
||||||
|
self.assertListEqual(list(out_npy.shape), list(out_mlx.shape))
|
||||||
|
self.assertTrue(np.allclose(out_mlx, out_npy.astype(np_dtype), atol=1e-5))
|
||||||
|
|
||||||
|
def test_matmul_unaligned(self):
|
||||||
|
if not mx.metal.is_available():
|
||||||
|
return
|
||||||
|
|
||||||
|
for dtype in self.dtypes:
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
base_shapes = [4, 8, 16, 32, 64, 128]
|
||||||
|
pertubations = [-2, -1, 0, 1, 2]
|
||||||
|
|
||||||
|
for dim in base_shapes:
|
||||||
|
for p in pertubations:
|
||||||
|
shape_a = (dim + p, dim + p)
|
||||||
|
shape_b = (dim + p, dim + p)
|
||||||
|
self.__gemm_test(shape_a, shape_b, np_dtype)
|
||||||
|
|
||||||
|
def test_matmul_shapes(self):
|
||||||
|
if not mx.metal.is_available():
|
||||||
|
return
|
||||||
|
|
||||||
|
shapes = [
|
||||||
|
(1, 2, 1, 1),
|
||||||
|
(1, 1, 2, 1),
|
||||||
|
(3, 23, 457, 3),
|
||||||
|
]
|
||||||
|
|
||||||
|
if mx.default_device() == mx.gpu:
|
||||||
|
shapes += [
|
||||||
|
(16, 768, 768, 128),
|
||||||
|
]
|
||||||
|
|
||||||
|
for dtype in self.dtypes:
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
|
||||||
|
for B, M, N, K in shapes:
|
||||||
|
|
||||||
|
with self.subTest(tranpose="nn"):
|
||||||
|
shape_a = (B, M, K)
|
||||||
|
shape_b = (B, K, N)
|
||||||
|
self.__gemm_test(shape_a, shape_b, np_dtype)
|
||||||
|
|
||||||
|
with self.subTest(tranpose="nt"):
|
||||||
|
shape_a = (B, M, K)
|
||||||
|
shape_b = (B, N, K)
|
||||||
|
self.__gemm_test(
|
||||||
|
shape_a,
|
||||||
|
shape_b,
|
||||||
|
np_dtype,
|
||||||
|
f_np_b=lambda x: np.transpose(x, (0, 2, 1)),
|
||||||
|
f_mx_b=lambda x: mx.transpose(x, (0, 2, 1)),
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.subTest(tranpose="tn"):
|
||||||
|
shape_a = (B, K, M)
|
||||||
|
shape_b = (B, K, N)
|
||||||
|
self.__gemm_test(
|
||||||
|
shape_a,
|
||||||
|
shape_b,
|
||||||
|
np_dtype,
|
||||||
|
f_np_a=lambda x: np.transpose(x, (0, 2, 1)),
|
||||||
|
f_mx_a=lambda x: mx.transpose(x, (0, 2, 1)),
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.subTest(tranpose="tt"):
|
||||||
|
shape_a = (B, K, M)
|
||||||
|
shape_b = (B, N, K)
|
||||||
|
self.__gemm_test(
|
||||||
|
shape_a,
|
||||||
|
shape_b,
|
||||||
|
np_dtype,
|
||||||
|
f_np_a=lambda x: np.transpose(x, (0, 2, 1)),
|
||||||
|
f_mx_a=lambda x: mx.transpose(x, (0, 2, 1)),
|
||||||
|
f_np_b=lambda x: np.transpose(x, (0, 2, 1)),
|
||||||
|
f_mx_b=lambda x: mx.transpose(x, (0, 2, 1)),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_matmul(self):
|
||||||
|
# Note: so far, matmul only works with floating-point types
|
||||||
|
a = mx.array([[1.0, 2.0], [3.0, 4.0]])
|
||||||
|
|
||||||
|
b = mx.array([[0.0, -1.0], [-3.0, 3.0]])
|
||||||
|
|
||||||
|
expected = [[-6.0, 5.0], [-12.0, 9.0]]
|
||||||
|
|
||||||
|
self.assertEqual((a @ b).tolist(), expected)
|
||||||
|
self.assertEqual(mx.matmul(a, b).tolist(), expected)
|
||||||
|
|
||||||
|
# Transposed matmul
|
||||||
|
np.random.seed(0)
|
||||||
|
a_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)
|
||||||
|
b_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)
|
||||||
|
c_npy = a_npy @ np.transpose(b_npy, (1, 0))
|
||||||
|
d_npy = np.transpose(a_npy, (1, 0)) @ b_npy
|
||||||
|
|
||||||
|
a_mlx = mx.array(a_npy)
|
||||||
|
b_mlx = mx.array(b_npy)
|
||||||
|
c_mlx = a_mlx @ mx.transpose(b_mlx, (1, 0))
|
||||||
|
d_mlx = mx.transpose(a_mlx, (1, 0)) @ b_mlx
|
||||||
|
|
||||||
|
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||||
|
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
||||||
|
|
||||||
|
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||||
|
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-6))
|
||||||
|
|
||||||
|
def test_matmul_dtypes(self):
|
||||||
|
|
||||||
|
for dt in self.dtypes:
|
||||||
|
a_npy = np.random.normal(0.0, 1.0 / 256, (16, 16, 16)).astype(
|
||||||
|
getattr(np, dt)
|
||||||
|
)
|
||||||
|
b_npy = np.random.normal(0.0, 1.0 / 256, (16, 16, 16)).astype(
|
||||||
|
getattr(np, dt)
|
||||||
|
)
|
||||||
|
a_mlx = mx.array(a_npy)
|
||||||
|
b_mlx = mx.array(b_npy)
|
||||||
|
|
||||||
|
c_npy = np.matmul(a_npy, b_npy, dtype=getattr(np, dt))
|
||||||
|
c_mlx = a_mlx @ b_mlx
|
||||||
|
|
||||||
|
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||||
|
|
||||||
|
def test_matmul_batched(self):
|
||||||
|
np.random.seed(0)
|
||||||
|
# Batched matmul
|
||||||
|
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||||
|
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
|
||||||
|
c_npy = a_npy @ b_npy
|
||||||
|
|
||||||
|
a_mlx = mx.array(a_npy)
|
||||||
|
b_mlx = mx.array(b_npy)
|
||||||
|
c_mlx = a_mlx @ b_mlx
|
||||||
|
|
||||||
|
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||||
|
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||||
|
|
||||||
|
# Batched and transposed matmul
|
||||||
|
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||||
|
c_npy = a_npy @ np.transpose(b_npy, (0, 2, 1))
|
||||||
|
|
||||||
|
b_mlx = mx.array(b_npy)
|
||||||
|
c_mlx = a_mlx @ mx.transpose(b_mlx, (0, 2, 1))
|
||||||
|
|
||||||
|
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||||
|
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||||
|
|
||||||
|
# Batched matmul with simple broadast
|
||||||
|
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||||
|
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
|
||||||
|
c_npy = a_npy @ b_npy
|
||||||
|
|
||||||
|
a_mlx = mx.array(a_npy)
|
||||||
|
b_mlx = mx.array(b_npy)
|
||||||
|
c_mlx = a_mlx @ b_mlx
|
||||||
|
|
||||||
|
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||||
|
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||||
|
|
||||||
|
# Both operands broadcasted
|
||||||
|
d_npy = np.broadcast_to(b_npy, (5, 16, 16))
|
||||||
|
d_mlx = mx.broadcast_to(b_mlx, (5, 16, 16))
|
||||||
|
|
||||||
|
e_npy = d_npy @ d_npy
|
||||||
|
e_mlx = d_mlx @ d_mlx
|
||||||
|
|
||||||
|
self.assertListEqual(list(e_npy.shape), list(e_mlx.shape))
|
||||||
|
self.assertTrue(np.allclose(e_mlx, e_npy, atol=1e-6))
|
||||||
|
|
||||||
|
# Batched and transposed matmul with simple broadast
|
||||||
|
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||||
|
b_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)
|
||||||
|
a_mlx = mx.array(a_npy)
|
||||||
|
b_mlx = mx.array(b_npy)
|
||||||
|
|
||||||
|
c_npy = a_npy @ np.transpose(b_npy, (1, 0))
|
||||||
|
c_mlx = a_mlx @ mx.transpose(b_mlx, (1, 0))
|
||||||
|
|
||||||
|
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||||
|
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||||
|
|
||||||
|
# Matmul with vector
|
||||||
|
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||||
|
b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
|
||||||
|
a_mlx = mx.array(a_npy)
|
||||||
|
b_mlx = mx.array(b_npy)
|
||||||
|
|
||||||
|
c_npy = a_npy @ b_npy
|
||||||
|
c_mlx = a_mlx @ b_mlx
|
||||||
|
|
||||||
|
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||||
|
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||||
|
|
||||||
|
# Test Multiheaded attention style matmul
|
||||||
|
a_npy = np.random.normal(0.0, 1.0 / 128, (64, 16, 4, 32)).astype(np.float32)
|
||||||
|
b_npy = np.random.normal(0.0, 1.0 / 128, (64, 16, 4, 32)).astype(np.float32)
|
||||||
|
a_mlx = mx.array(a_npy)
|
||||||
|
b_mlx = mx.array(b_npy)
|
||||||
|
|
||||||
|
a_npy = np.transpose(a_npy, (0, 2, 1, 3))
|
||||||
|
b_npy = np.transpose(b_npy, (0, 2, 1, 3))
|
||||||
|
a_mlx = mx.transpose(a_mlx, (0, 2, 1, 3))
|
||||||
|
b_mlx = mx.transpose(b_mlx, (0, 2, 1, 3))
|
||||||
|
|
||||||
|
c_npy = a_npy @ np.transpose(b_npy, (0, 1, 3, 2))
|
||||||
|
c_mlx = a_mlx @ mx.transpose(b_mlx, (0, 1, 3, 2))
|
||||||
|
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||||
|
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||||
|
|
||||||
|
def __gemv_test(
|
||||||
|
self,
|
||||||
|
shape_mat,
|
||||||
|
shape_vec,
|
||||||
|
np_dtype=np.float32,
|
||||||
|
mat_first=True,
|
||||||
|
np_mat_f=lambda x: x,
|
||||||
|
np_vec_f=lambda x: x,
|
||||||
|
mlx_mat_f=lambda x: x,
|
||||||
|
mlx_vec_f=lambda x: x,
|
||||||
|
):
|
||||||
|
with self.subTest(shape=shape_mat):
|
||||||
|
np.random.seed(42)
|
||||||
|
scale = max(np.sum(shape_mat), 32)
|
||||||
|
mat_npy = np.random.normal(0.0, 1.0 / scale, shape_mat).astype(np_dtype)
|
||||||
|
vec_npy = np.random.normal(0.0, 1.0 / scale, shape_vec).astype(np_dtype)
|
||||||
|
|
||||||
|
mat_mlx = mx.array(mat_npy)
|
||||||
|
vec_mlx = mx.array(vec_npy)
|
||||||
|
|
||||||
|
mat_npy = np_mat_f(mat_npy)
|
||||||
|
vec_npy = np_vec_f(vec_npy)
|
||||||
|
mat_mlx = mlx_mat_f(mat_mlx)
|
||||||
|
vec_mlx = mlx_vec_f(vec_mlx)
|
||||||
|
|
||||||
|
if mat_first:
|
||||||
|
out_npy = mat_npy @ vec_npy
|
||||||
|
out_mlx = mat_mlx @ vec_mlx
|
||||||
|
else:
|
||||||
|
out_npy = vec_npy @ mat_npy
|
||||||
|
out_mlx = vec_mlx @ mat_mlx
|
||||||
|
|
||||||
|
self.assertListEqual(list(out_npy.shape), list(out_mlx.shape))
|
||||||
|
self.assertTrue(np.allclose(out_mlx, out_npy, atol=1e-5))
|
||||||
|
|
||||||
|
def test_matrix_vector(self):
|
||||||
|
for dtype in self.dtypes:
|
||||||
|
with self.subTest(dtype=dtype):
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
|
||||||
|
# Basic square matrix test
|
||||||
|
self.__gemv_test(
|
||||||
|
shape_mat=(64, 64), shape_vec=(64, 1), np_dtype=np_dtype
|
||||||
|
)
|
||||||
|
self.__gemv_test(
|
||||||
|
shape_mat=(64, 64),
|
||||||
|
shape_vec=(64, 1),
|
||||||
|
np_dtype=np_dtype,
|
||||||
|
mat_first=False,
|
||||||
|
np_vec_f=lambda x: np.transpose(x, (1, 0)),
|
||||||
|
mlx_vec_f=lambda x: mx.transpose(x, (1, 0)),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Vector matrix product with aligned and unaligned shapes
|
||||||
|
for in_len_base, out_len_base in (
|
||||||
|
(2, 2),
|
||||||
|
(32, 32),
|
||||||
|
(64, 64),
|
||||||
|
(2048, 2048),
|
||||||
|
):
|
||||||
|
for mi in (-1, 0, 1):
|
||||||
|
for mj in (-1, 0, 1):
|
||||||
|
# Vec mat
|
||||||
|
shape_mat = (in_len_base + mi, out_len_base + mj)
|
||||||
|
shape_vec = (1, in_len_base + mi)
|
||||||
|
self.__gemv_test(
|
||||||
|
shape_mat, shape_vec, mat_first=False, np_dtype=np_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mat vec
|
||||||
|
shape_mat = (out_len_base + mj, in_len_base + mi)
|
||||||
|
shape_vec = (in_len_base + mi, 1)
|
||||||
|
self.__gemv_test(
|
||||||
|
shape_mat, shape_vec, mat_first=True, np_dtype=np_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_matrix_vector_batched(self):
|
||||||
|
for dtype in self.dtypes:
|
||||||
|
with self.subTest(dtype=dtype):
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
|
||||||
|
# Batched mat vec
|
||||||
|
for shape_mat, shape_vec in (
|
||||||
|
((32, 128, 64), (32, 64, 1)),
|
||||||
|
((128, 64), (32, 64, 1)),
|
||||||
|
((32, 128, 64), (64, 1)),
|
||||||
|
):
|
||||||
|
self.__gemv_test(
|
||||||
|
shape_mat, shape_vec, mat_first=True, np_dtype=np_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
# Batched vec mat
|
||||||
|
for shape_vec, shape_mat in (
|
||||||
|
((32, 1, 128), (32, 128, 64)),
|
||||||
|
((32, 1, 128), (128, 64)),
|
||||||
|
((1, 128), (32, 128, 64)),
|
||||||
|
):
|
||||||
|
self.__gemv_test(
|
||||||
|
shape_mat, shape_vec, mat_first=False, np_dtype=np_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_matrix_vector_broadcast(self):
|
||||||
|
for dtype in self.dtypes:
|
||||||
|
with self.subTest(dtype=dtype):
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
|
||||||
|
# Different broadcasts mat vec
|
||||||
|
for shape_mat, shape_vec in (
|
||||||
|
((32, 64, 64), (32, 64, 1)),
|
||||||
|
((64, 64), (32, 64, 1)),
|
||||||
|
((32, 64, 64), (64, 1)),
|
||||||
|
):
|
||||||
|
self.__gemv_test(
|
||||||
|
shape_mat=(64, 64),
|
||||||
|
shape_vec=(64, 1),
|
||||||
|
np_dtype=np_dtype,
|
||||||
|
np_mat_f=(lambda mat_npy: np.broadcast_to(mat_npy, shape_mat)),
|
||||||
|
np_vec_f=(lambda vec_npy: np.broadcast_to(vec_npy, shape_vec)),
|
||||||
|
mlx_mat_f=(lambda mat_mlx: mx.broadcast_to(mat_mlx, shape_mat)),
|
||||||
|
mlx_vec_f=(lambda vec_mlx: mx.broadcast_to(vec_mlx, shape_vec)),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Different broadcasts vec mat
|
||||||
|
for shape_vec, shape_mat in (
|
||||||
|
((32, 1, 64), (32, 64, 64)),
|
||||||
|
((32, 1, 64), (64, 64)),
|
||||||
|
((1, 64), (32, 64, 64)),
|
||||||
|
):
|
||||||
|
self.__gemv_test(
|
||||||
|
shape_mat=(64, 64),
|
||||||
|
shape_vec=(1, 64),
|
||||||
|
np_dtype=np_dtype,
|
||||||
|
mat_first=False,
|
||||||
|
np_mat_f=lambda mat_npy: np.broadcast_to(mat_npy, shape_mat),
|
||||||
|
np_vec_f=lambda vec_npy: np.broadcast_to(vec_npy, shape_vec),
|
||||||
|
mlx_mat_f=lambda mat_mlx: mx.broadcast_to(mat_mlx, shape_mat),
|
||||||
|
mlx_vec_f=lambda vec_mlx: mx.broadcast_to(vec_mlx, shape_vec),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_matrix_vector_edgecases(self):
|
||||||
|
for dtype in self.dtypes:
|
||||||
|
with self.subTest(dtype=dtype):
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
|
||||||
|
for in_vec_len in np.arange(1, 5):
|
||||||
|
for out_vec_len in np.arange(1, 5):
|
||||||
|
for batch_size in np.arange(1, 5):
|
||||||
|
with self.subTest(
|
||||||
|
problem_shape=(batch_size, in_vec_len, out_vec_len)
|
||||||
|
):
|
||||||
|
# Matrix vector
|
||||||
|
with self.subTest(transpose=False):
|
||||||
|
a_npy = np.ones(
|
||||||
|
(batch_size, out_vec_len, in_vec_len),
|
||||||
|
dtype=np_dtype,
|
||||||
|
)
|
||||||
|
b_npy = np.ones(
|
||||||
|
(batch_size, in_vec_len, 1), dtype=np_dtype
|
||||||
|
)
|
||||||
|
for i in range(batch_size):
|
||||||
|
b_npy[i] *= i + 1.0
|
||||||
|
|
||||||
|
a_mlx, b_mlx = map(mx.array, [a_npy, b_npy])
|
||||||
|
c_npy = a_npy @ b_npy
|
||||||
|
c_mlx = a_mlx @ b_mlx
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
list(c_npy.shape), list(c_mlx.shape)
|
||||||
|
)
|
||||||
|
self.assertTrue(np.array_equal(c_mlx, c_npy))
|
||||||
|
|
||||||
|
# Vector matrix
|
||||||
|
with self.subTest(transpose=True):
|
||||||
|
a_npy = np.ones(
|
||||||
|
(batch_size, out_vec_len, in_vec_len),
|
||||||
|
dtype=np_dtype,
|
||||||
|
)
|
||||||
|
b_npy = np.ones(
|
||||||
|
(batch_size, 1, out_vec_len), dtype=np_dtype
|
||||||
|
)
|
||||||
|
for i in range(batch_size):
|
||||||
|
b_npy[i] *= i + 1.0
|
||||||
|
|
||||||
|
a_mlx, b_mlx = map(mx.array, [a_npy, b_npy])
|
||||||
|
c_npy = b_npy @ a_npy
|
||||||
|
c_mlx = b_mlx @ a_mlx
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
list(c_npy.shape), list(c_mlx.shape)
|
||||||
|
)
|
||||||
|
self.assertTrue(np.array_equal(c_mlx, c_npy))
|
445
python/tests/test_conv.py
Normal file
445
python/tests/test_conv.py
Normal file
@ -0,0 +1,445 @@
|
|||||||
|
import unittest
|
||||||
|
from itertools import permutations
|
||||||
|
|
||||||
|
import math
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mlx_tests
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
has_torch = True
|
||||||
|
except ImportError as e:
|
||||||
|
has_torch = False
|
||||||
|
|
||||||
|
|
||||||
|
class TestConv(mlx_tests.MLXTestCase):
|
||||||
|
def test_numpy_conv(self):
|
||||||
|
for dtype in (
|
||||||
|
"float16",
|
||||||
|
"float32",
|
||||||
|
):
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
for M, N, mode in (
|
||||||
|
(1, 1, "full"),
|
||||||
|
(25, 5, "full"),
|
||||||
|
(24, 5, "same"),
|
||||||
|
(24, 4, "same"),
|
||||||
|
(24, 4, "valid"),
|
||||||
|
(4, 24, "full"),
|
||||||
|
(5, 25, "same"),
|
||||||
|
(4, 25, "valid"),
|
||||||
|
):
|
||||||
|
with self.subTest(dtype=dtype, M=M, N=N, mode=mode):
|
||||||
|
atol = 1e-6 if dtype == "float32" else 1e-5
|
||||||
|
a_np = np.random.rand(M).astype(np_dtype)
|
||||||
|
v_np = np.random.rand(N).astype(np_dtype)
|
||||||
|
a_mx = mx.array(a_np)
|
||||||
|
v_mx = mx.array(v_np)
|
||||||
|
|
||||||
|
c_np = np.convolve(a_np, v_np, mode=mode)
|
||||||
|
c_mx = mx.convolve(a_mx, v_mx, mode=mode)
|
||||||
|
|
||||||
|
self.assertListEqual(list(c_mx.shape), list(c_np.shape))
|
||||||
|
self.assertTrue(np.allclose(c_mx, c_np, atol=atol))
|
||||||
|
|
||||||
|
@unittest.skipIf(not has_torch, "requires Torch")
|
||||||
|
def test_torch_conv_1D(self):
|
||||||
|
def run_conv1D(
|
||||||
|
N,
|
||||||
|
C,
|
||||||
|
O,
|
||||||
|
iH,
|
||||||
|
kH,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
dilation=1,
|
||||||
|
groups=1,
|
||||||
|
dtype="float32",
|
||||||
|
atol=1e-5,
|
||||||
|
):
|
||||||
|
with self.subTest(
|
||||||
|
dtype=dtype,
|
||||||
|
N=N,
|
||||||
|
C=C,
|
||||||
|
O=O,
|
||||||
|
iH=iH,
|
||||||
|
kH=kH,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=groups,
|
||||||
|
):
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
np.random.seed(0)
|
||||||
|
in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
|
||||||
|
wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)
|
||||||
|
|
||||||
|
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||||
|
in_pt, wt_pt = map(
|
||||||
|
lambda x: torch.from_numpy(x.transpose(0, 2, 1)), (in_np, wt_np)
|
||||||
|
)
|
||||||
|
|
||||||
|
out_mx = mx.conv1d(
|
||||||
|
in_mx,
|
||||||
|
wt_mx,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=groups,
|
||||||
|
)
|
||||||
|
out_pt = torch.conv1d(
|
||||||
|
in_pt,
|
||||||
|
wt_pt,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=groups,
|
||||||
|
)
|
||||||
|
out_pt = torch.transpose(out_pt, 2, 1)
|
||||||
|
|
||||||
|
self.assertListEqual(list(out_pt.shape), out_mx.shape)
|
||||||
|
self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))
|
||||||
|
|
||||||
|
for dtype in ("float32",):
|
||||||
|
for N, C, O in (
|
||||||
|
(1, 1, 1),
|
||||||
|
(1, 6, 1),
|
||||||
|
(1, 1, 6),
|
||||||
|
(4, 32, 64),
|
||||||
|
):
|
||||||
|
for iH, kH, stride, padding in (
|
||||||
|
(1, 1, 1, 0),
|
||||||
|
(3, 3, 1, 0),
|
||||||
|
(31, 5, 5, 2),
|
||||||
|
):
|
||||||
|
run_conv1D(N, C, O, iH, kH, stride, padding, dtype=dtype)
|
||||||
|
|
||||||
|
# Strided inputs tests
|
||||||
|
for tpose_in, tpose_wt in (
|
||||||
|
((0, 2, 1), (0, 1, 2)),
|
||||||
|
((0, 2, 1), (0, 2, 1)),
|
||||||
|
):
|
||||||
|
with self.subTest(name="strided", tpose_in=tpose_in, tpose_wt=tpose_wt):
|
||||||
|
in_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)
|
||||||
|
wt_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)
|
||||||
|
|
||||||
|
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||||
|
in_mx_t = mx.transpose(in_mx, tpose_in)
|
||||||
|
wt_mx_t = mx.transpose(wt_mx, tpose_wt)
|
||||||
|
out_mx = mx.conv1d(in_mx_t, wt_mx_t)
|
||||||
|
|
||||||
|
in_pt, wt_pt = map(
|
||||||
|
lambda x: torch.from_numpy(x.transpose(0, 2, 1)),
|
||||||
|
(in_np.transpose(tpose_in), wt_np.transpose(tpose_wt)),
|
||||||
|
)
|
||||||
|
|
||||||
|
out_pt = torch.conv1d(in_pt, wt_pt)
|
||||||
|
out_pt = torch.transpose(out_pt, 2, 1)
|
||||||
|
|
||||||
|
self.assertListEqual(list(out_pt.shape), out_mx.shape)
|
||||||
|
self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=1e-5))
|
||||||
|
|
||||||
|
@unittest.skipIf(not has_torch, "requires Torch")
|
||||||
|
def test_torch_conv_1D_grad(self):
|
||||||
|
def run_conv1D_grad(
|
||||||
|
N,
|
||||||
|
C,
|
||||||
|
O,
|
||||||
|
iH,
|
||||||
|
kH,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
dilation=1,
|
||||||
|
groups=1,
|
||||||
|
dtype="float32",
|
||||||
|
atol=1e-5,
|
||||||
|
):
|
||||||
|
with self.subTest(
|
||||||
|
dtype=dtype,
|
||||||
|
N=N,
|
||||||
|
C=C,
|
||||||
|
O=O,
|
||||||
|
iH=iH,
|
||||||
|
kH=kH,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=groups,
|
||||||
|
):
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
np.random.seed(0)
|
||||||
|
oH = 1 + ((iH + 2 * padding - dilation * (kH - 1) - 1) // stride)
|
||||||
|
|
||||||
|
in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
|
||||||
|
wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)
|
||||||
|
ct_np = np.random.normal(0, 1.0 / C, (N, oH, O)).astype(np_dtype)
|
||||||
|
|
||||||
|
in_mx, wt_mx, ct_mx = map(mx.array, (in_np, wt_np, ct_np))
|
||||||
|
in_pt, wt_pt, ct_pt = map(
|
||||||
|
lambda x: torch.from_numpy(x.transpose(0, 2, 1)),
|
||||||
|
(in_np, wt_np, ct_np),
|
||||||
|
)
|
||||||
|
|
||||||
|
def f(a, b):
|
||||||
|
return mx.conv1d(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=groups,
|
||||||
|
)
|
||||||
|
|
||||||
|
_, outs_mx = mx.vjp(
|
||||||
|
f,
|
||||||
|
[
|
||||||
|
in_mx,
|
||||||
|
wt_mx,
|
||||||
|
],
|
||||||
|
[
|
||||||
|
ct_mx,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
pt_grad_in = F.grad.conv1d_input(
|
||||||
|
in_pt.shape,
|
||||||
|
wt_pt,
|
||||||
|
ct_pt,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=groups,
|
||||||
|
)
|
||||||
|
pt_grad_wt = F.grad.conv1d_weight(
|
||||||
|
in_pt,
|
||||||
|
wt_pt.shape,
|
||||||
|
ct_pt,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=groups,
|
||||||
|
)
|
||||||
|
pt_grad_in = torch.transpose(pt_grad_in, 2, 1).numpy()
|
||||||
|
pt_grad_wt = torch.transpose(pt_grad_wt, 2, 1).numpy()
|
||||||
|
|
||||||
|
mx_grad_in, mx_grad_wt = outs_mx
|
||||||
|
|
||||||
|
self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape)
|
||||||
|
self.assertListEqual(list(in_mx.shape), mx_grad_in.shape)
|
||||||
|
self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
|
||||||
|
|
||||||
|
self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape)
|
||||||
|
self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape)
|
||||||
|
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
|
||||||
|
|
||||||
|
for dtype in ("float32",):
|
||||||
|
for N, C, O in (
|
||||||
|
(1, 1, 1),
|
||||||
|
(1, 6, 1),
|
||||||
|
(1, 1, 6),
|
||||||
|
(4, 32, 64),
|
||||||
|
):
|
||||||
|
for iH, kH, stride, padding in (
|
||||||
|
(1, 1, 1, 0),
|
||||||
|
(3, 3, 1, 0),
|
||||||
|
(31, 5, 5, 2),
|
||||||
|
):
|
||||||
|
run_conv1D_grad(N, C, O, iH, kH, stride, padding, dtype=dtype)
|
||||||
|
|
||||||
|
@unittest.skipIf(not has_torch, "requires Torch")
|
||||||
|
def test_torch_conv_2D(self):
|
||||||
|
def run_conv2D(
|
||||||
|
N,
|
||||||
|
C,
|
||||||
|
O,
|
||||||
|
idim,
|
||||||
|
kdim,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
dilation=(1, 1),
|
||||||
|
groups=1,
|
||||||
|
dtype="float32",
|
||||||
|
atol=1e-5,
|
||||||
|
):
|
||||||
|
with self.subTest(
|
||||||
|
dtype=dtype,
|
||||||
|
N=N,
|
||||||
|
C=C,
|
||||||
|
O=O,
|
||||||
|
idim=idim,
|
||||||
|
kdim=kdim,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=groups,
|
||||||
|
):
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
np.random.seed(0)
|
||||||
|
iH, iW = idim
|
||||||
|
kH, kW = kdim
|
||||||
|
scale = 1.0 / math.sqrt(kH * kW * C)
|
||||||
|
in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
|
||||||
|
wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, C)).astype(np_dtype)
|
||||||
|
|
||||||
|
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||||
|
in_pt, wt_pt = map(
|
||||||
|
lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2)).to("cpu"),
|
||||||
|
(in_np, wt_np),
|
||||||
|
)
|
||||||
|
|
||||||
|
out_mx = mx.conv2d(
|
||||||
|
in_mx,
|
||||||
|
wt_mx,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=groups,
|
||||||
|
)
|
||||||
|
out_pt = torch.conv2d(
|
||||||
|
in_pt,
|
||||||
|
wt_pt,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=groups,
|
||||||
|
)
|
||||||
|
out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
|
||||||
|
|
||||||
|
self.assertListEqual(list(out_pt.shape), list(out_mx.shape))
|
||||||
|
self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
|
||||||
|
|
||||||
|
for dtype in ("float32",):
|
||||||
|
for N, C, O in (
|
||||||
|
(1, 1, 1),
|
||||||
|
(1, 6, 1),
|
||||||
|
(1, 1, 6),
|
||||||
|
(4, 32, 64),
|
||||||
|
):
|
||||||
|
for idim, kdim, stride, padding in (
|
||||||
|
((1, 1), (1, 1), (1, 1), (0, 0)),
|
||||||
|
((3, 3), (3, 1), (1, 1), (0, 0)),
|
||||||
|
((31, 31), (5, 5), (5, 5), (2, 2)),
|
||||||
|
):
|
||||||
|
run_conv2D(N, C, O, idim, kdim, stride, padding, dtype=dtype)
|
||||||
|
|
||||||
|
@unittest.skipIf(not has_torch, "requires Torch")
|
||||||
|
def test_torch_conv_2D_grad(self):
|
||||||
|
def run_conv2D_grad(
|
||||||
|
N,
|
||||||
|
C,
|
||||||
|
O,
|
||||||
|
idim,
|
||||||
|
kdim,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
dilation=(1, 1),
|
||||||
|
groups=1,
|
||||||
|
dtype="float32",
|
||||||
|
atol=1e-5,
|
||||||
|
):
|
||||||
|
with self.subTest(
|
||||||
|
dtype=dtype,
|
||||||
|
N=N,
|
||||||
|
C=C,
|
||||||
|
O=O,
|
||||||
|
idim=idim,
|
||||||
|
kdim=kdim,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=groups,
|
||||||
|
):
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
np.random.seed(0)
|
||||||
|
iH, iW = idim
|
||||||
|
kH, kW = kdim
|
||||||
|
scale = 1.0 / math.sqrt(kH * kW * C)
|
||||||
|
|
||||||
|
oH = 1 + (
|
||||||
|
(iH + 2 * padding[0] - dilation[0] * (kH - 1) - 1) // stride[0]
|
||||||
|
)
|
||||||
|
oW = 1 + (
|
||||||
|
(iW + 2 * padding[1] - dilation[1] * (kW - 1) - 1) // stride[1]
|
||||||
|
)
|
||||||
|
|
||||||
|
in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
|
||||||
|
wt_np = np.random.normal(0.0, scale, (O, kH, kW, C)).astype(np_dtype)
|
||||||
|
ct_np = np.random.normal(0.0, scale, (N, oH, oW, O)).astype(np_dtype)
|
||||||
|
|
||||||
|
in_mx, wt_mx, ct_mx = map(mx.array, (in_np, wt_np, ct_np))
|
||||||
|
in_pt, wt_pt, ct_pt = map(
|
||||||
|
lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2)).to("cpu"),
|
||||||
|
(in_np, wt_np, ct_np),
|
||||||
|
)
|
||||||
|
|
||||||
|
def f(a, b):
|
||||||
|
return mx.conv2d(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=groups,
|
||||||
|
)
|
||||||
|
|
||||||
|
_, outs_mx = mx.vjp(
|
||||||
|
f,
|
||||||
|
[
|
||||||
|
in_mx,
|
||||||
|
wt_mx,
|
||||||
|
],
|
||||||
|
[
|
||||||
|
ct_mx,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
pt_grad_in = F.grad.conv1d_input(
|
||||||
|
in_pt.shape,
|
||||||
|
wt_pt,
|
||||||
|
ct_pt,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=groups,
|
||||||
|
)
|
||||||
|
pt_grad_wt = F.grad.conv1d_weight(
|
||||||
|
in_pt,
|
||||||
|
wt_pt.shape,
|
||||||
|
ct_pt,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=groups,
|
||||||
|
)
|
||||||
|
pt_grad_in = torch.permute(pt_grad_in, (0, 2, 3, 1)).numpy()
|
||||||
|
pt_grad_wt = torch.permute(pt_grad_wt, (0, 2, 3, 1)).numpy()
|
||||||
|
|
||||||
|
mx_grad_in, mx_grad_wt = outs_mx
|
||||||
|
|
||||||
|
self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape)
|
||||||
|
self.assertListEqual(list(in_mx.shape), mx_grad_in.shape)
|
||||||
|
self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
|
||||||
|
|
||||||
|
self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape)
|
||||||
|
self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape)
|
||||||
|
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
|
||||||
|
|
||||||
|
for dtype in ("float32",):
|
||||||
|
for N, C, O in (
|
||||||
|
(1, 1, 1),
|
||||||
|
(1, 6, 1),
|
||||||
|
(1, 1, 6),
|
||||||
|
(4, 32, 64),
|
||||||
|
):
|
||||||
|
for idim, kdim, stride, padding in (
|
||||||
|
((1, 1), (1, 1), (1, 1), (0, 0)),
|
||||||
|
((3, 3), (3, 1), (1, 1), (0, 0)),
|
||||||
|
((31, 31), (5, 5), (5, 5), (2, 2)),
|
||||||
|
):
|
||||||
|
run_conv2D_grad(N, C, O, idim, kdim, stride, padding, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
157
python/tests/test_load.py
Normal file
157
python/tests/test_load.py
Normal file
@ -0,0 +1,157 @@
|
|||||||
|
import unittest
|
||||||
|
import os
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import mlx_tests
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoad(mlx_tests.MLXTestCase):
|
||||||
|
dtypes = [
|
||||||
|
"uint8",
|
||||||
|
"uint16",
|
||||||
|
"uint32",
|
||||||
|
"uint64",
|
||||||
|
"int8",
|
||||||
|
"int16",
|
||||||
|
"int32",
|
||||||
|
"int64",
|
||||||
|
"float32",
|
||||||
|
"float16",
|
||||||
|
"complex64",
|
||||||
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.test_dir_fid = tempfile.TemporaryDirectory()
|
||||||
|
cls.test_dir = cls.test_dir_fid.name
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
cls.test_dir_fid.cleanup()
|
||||||
|
|
||||||
|
def test_save_and_load(self):
|
||||||
|
|
||||||
|
if not os.path.isdir(self.test_dir):
|
||||||
|
os.mkdir(self.test_dir)
|
||||||
|
|
||||||
|
for dt in self.dtypes:
|
||||||
|
with self.subTest(dtype=dt):
|
||||||
|
for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
|
||||||
|
with self.subTest(shape=shape):
|
||||||
|
save_file_mlx = os.path.join(self.test_dir, f"mlx_{dt}_{i}.npy")
|
||||||
|
save_file_npy = os.path.join(self.test_dir, f"npy_{dt}_{i}.npy")
|
||||||
|
|
||||||
|
save_arr = np.random.uniform(0.0, 32.0, size=shape)
|
||||||
|
save_arr_npy = save_arr.astype(getattr(np, dt))
|
||||||
|
save_arr_mlx = mx.array(save_arr_npy)
|
||||||
|
|
||||||
|
mx.save(save_file_mlx, save_arr_mlx)
|
||||||
|
np.save(save_file_npy, save_arr_npy)
|
||||||
|
|
||||||
|
# Load array saved by mlx as mlx array
|
||||||
|
load_arr_mlx_mlx = mx.load(save_file_mlx)
|
||||||
|
self.assertTrue(mx.array_equal(load_arr_mlx_mlx, save_arr_mlx))
|
||||||
|
|
||||||
|
# Load array saved by numpy as mlx array
|
||||||
|
load_arr_npy_mlx = mx.load(save_file_npy)
|
||||||
|
self.assertTrue(mx.array_equal(load_arr_npy_mlx, save_arr_mlx))
|
||||||
|
|
||||||
|
# Load array saved by mlx as numpy array
|
||||||
|
load_arr_mlx_npy = np.load(save_file_mlx)
|
||||||
|
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
|
||||||
|
|
||||||
|
def test_save_and_load_fs(self):
|
||||||
|
|
||||||
|
if not os.path.isdir(self.test_dir):
|
||||||
|
os.mkdir(self.test_dir)
|
||||||
|
|
||||||
|
for dt in self.dtypes:
|
||||||
|
with self.subTest(dtype=dt):
|
||||||
|
for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
|
||||||
|
with self.subTest(shape=shape):
|
||||||
|
save_file_mlx = os.path.join(
|
||||||
|
self.test_dir, f"mlx_{dt}_{i}_fs.npy"
|
||||||
|
)
|
||||||
|
save_file_npy = os.path.join(
|
||||||
|
self.test_dir, f"npy_{dt}_{i}_fs.npy"
|
||||||
|
)
|
||||||
|
|
||||||
|
save_arr = np.random.uniform(0.0, 32.0, size=shape)
|
||||||
|
save_arr_npy = save_arr.astype(getattr(np, dt))
|
||||||
|
save_arr_mlx = mx.array(save_arr_npy)
|
||||||
|
|
||||||
|
with open(save_file_mlx, "wb") as f:
|
||||||
|
mx.save(f, save_arr_mlx)
|
||||||
|
|
||||||
|
np.save(save_file_npy, save_arr_npy)
|
||||||
|
|
||||||
|
# Load array saved by mlx as mlx array
|
||||||
|
with open(save_file_mlx, "rb") as f:
|
||||||
|
load_arr_mlx_mlx = mx.load(f)
|
||||||
|
self.assertTrue(mx.array_equal(load_arr_mlx_mlx, save_arr_mlx))
|
||||||
|
|
||||||
|
# Load array saved by numpy as mlx array
|
||||||
|
with open(save_file_npy, "rb") as f:
|
||||||
|
load_arr_npy_mlx = mx.load(f)
|
||||||
|
self.assertTrue(mx.array_equal(load_arr_npy_mlx, save_arr_mlx))
|
||||||
|
|
||||||
|
# Load array saved by mlx as numpy array
|
||||||
|
load_arr_mlx_npy = np.load(save_file_mlx)
|
||||||
|
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
|
||||||
|
|
||||||
|
def test_savez_and_loadz(self):
|
||||||
|
if not os.path.isdir(self.test_dir):
|
||||||
|
os.mkdir(self.test_dir)
|
||||||
|
|
||||||
|
for dt in self.dtypes:
|
||||||
|
with self.subTest(dtype=dt):
|
||||||
|
shapes = [(6,), (6, 6), (4, 1, 3, 1, 2)]
|
||||||
|
save_file_mlx_uncomp = os.path.join(
|
||||||
|
self.test_dir, f"mlx_{dt}_uncomp.npz"
|
||||||
|
)
|
||||||
|
save_file_npy_uncomp = os.path.join(
|
||||||
|
self.test_dir, f"npy_{dt}_uncomp.npz"
|
||||||
|
)
|
||||||
|
save_file_mlx_comp = os.path.join(self.test_dir, f"mlx_{dt}_comp.npz")
|
||||||
|
save_file_npy_comp = os.path.join(self.test_dir, f"npy_{dt}_comp.npz")
|
||||||
|
|
||||||
|
# Make dictionary of multiple
|
||||||
|
save_arrs_npy = {
|
||||||
|
f"save_arr_{i}": np.random.uniform(
|
||||||
|
0.0, 32.0, size=shapes[i]
|
||||||
|
).astype(getattr(np, dt))
|
||||||
|
for i in range(len(shapes))
|
||||||
|
}
|
||||||
|
save_arrs_mlx = {k: mx.array(v) for k, v in save_arrs_npy.items()}
|
||||||
|
|
||||||
|
# Save as npz files
|
||||||
|
np.savez(save_file_npy_uncomp, **save_arrs_npy)
|
||||||
|
mx.savez(save_file_mlx_uncomp, **save_arrs_mlx)
|
||||||
|
np.savez_compressed(save_file_npy_comp, **save_arrs_npy)
|
||||||
|
mx.savez_compressed(save_file_mlx_comp, **save_arrs_mlx)
|
||||||
|
|
||||||
|
for save_file_npy, save_file_mlx in (
|
||||||
|
(save_file_npy_uncomp, save_file_mlx_uncomp),
|
||||||
|
(save_file_npy_comp, save_file_mlx_comp),
|
||||||
|
):
|
||||||
|
|
||||||
|
# Load array saved by mlx as mlx array
|
||||||
|
load_arr_mlx_mlx = mx.load(save_file_mlx)
|
||||||
|
for k, v in load_arr_mlx_mlx.items():
|
||||||
|
self.assertTrue(mx.array_equal(save_arrs_mlx[k], v))
|
||||||
|
|
||||||
|
# Load arrays saved by numpy as mlx arrays
|
||||||
|
load_arr_npy_mlx = mx.load(save_file_npy)
|
||||||
|
for k, v in load_arr_npy_mlx.items():
|
||||||
|
self.assertTrue(mx.array_equal(save_arrs_mlx[k], v))
|
||||||
|
|
||||||
|
# Load array saved by mlx as numpy array
|
||||||
|
load_arr_mlx_npy = np.load(save_file_mlx)
|
||||||
|
for k, v in load_arr_mlx_npy.items():
|
||||||
|
self.assertTrue(np.array_equal(save_arrs_npy[k], v))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
231
python/tests/test_nn.py
Normal file
231
python/tests/test_nn.py
Normal file
@ -0,0 +1,231 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import mlx_tests
|
||||||
|
|
||||||
|
|
||||||
|
class TestNN(mlx_tests.MLXTestCase):
|
||||||
|
def test_linear(self):
|
||||||
|
inputs = mx.zeros((10, 4))
|
||||||
|
layer = nn.Linear(input_dims=4, output_dims=8)
|
||||||
|
outputs = layer(inputs)
|
||||||
|
self.assertEqual(tuple(outputs.shape), (10, 8))
|
||||||
|
|
||||||
|
def test_cross_entropy(self):
|
||||||
|
logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
|
||||||
|
targets = mx.array([0, 1])
|
||||||
|
losses = nn.losses.cross_entropy(logits, targets)
|
||||||
|
self.assertTrue(mx.array_equal(losses, mx.zeros((2,))))
|
||||||
|
|
||||||
|
def test_gelu(self):
|
||||||
|
inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]
|
||||||
|
|
||||||
|
# From: jax.nn.gelu(np.array(inputs), approximate=False)
|
||||||
|
expected = np.array(
|
||||||
|
[1.0093501, -0.16925684, 0.22918941, 0.60498625, 0.49459383]
|
||||||
|
)
|
||||||
|
|
||||||
|
out = nn.GELU()(mx.array(inputs))
|
||||||
|
self.assertTrue(np.allclose(out, expected))
|
||||||
|
|
||||||
|
# Crudely check the approximations
|
||||||
|
x = mx.arange(-6.0, 6.0, 12 / 100)
|
||||||
|
y = nn.gelu(x)
|
||||||
|
y_hat1 = nn.gelu_approx(x)
|
||||||
|
y_hat2 = nn.gelu_fast_approx(x)
|
||||||
|
self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
|
||||||
|
self.assertLess(mx.abs(y - y_hat2).max(), 0.02)
|
||||||
|
|
||||||
|
def test_group_norm(self):
|
||||||
|
x = mx.arange(100, dtype=mx.float32)
|
||||||
|
x = x.reshape(1, 10, 10, 1)
|
||||||
|
x = mx.broadcast_to(x, (2, 10, 10, 4))
|
||||||
|
x = mx.concatenate([x, 0.5 * x], axis=-1)
|
||||||
|
|
||||||
|
# Group norm in groups last mode
|
||||||
|
g = nn.GroupNorm(2, 8)
|
||||||
|
y = g(x)
|
||||||
|
means = y.reshape(2, -1, 2).mean(axis=1)
|
||||||
|
var = y.reshape(2, -1, 2).var(axis=1)
|
||||||
|
self.assertTrue(np.allclose(means, np.zeros_like(means), atol=1e-6))
|
||||||
|
self.assertTrue(np.allclose(var, np.ones_like(var), atol=1e-6))
|
||||||
|
g.weight = g.weight * 2
|
||||||
|
g.bias = g.bias + 3
|
||||||
|
y = g(x)
|
||||||
|
means = y.reshape(2, -1, 2).mean(axis=1)
|
||||||
|
var = y.reshape(2, -1, 2).var(axis=1)
|
||||||
|
self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
|
||||||
|
self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))
|
||||||
|
|
||||||
|
# Group norm in groups first mode
|
||||||
|
g = nn.GroupNorm(2, 8, pytorch_compatible=True)
|
||||||
|
y = g(x)
|
||||||
|
means = y.reshape(2, -1, 2, 4).mean(axis=(1, -1))
|
||||||
|
var = y.reshape(2, -1, 2, 4).var(axis=(1, -1))
|
||||||
|
self.assertTrue(np.allclose(means, np.zeros_like(means), atol=1e-6))
|
||||||
|
self.assertTrue(np.allclose(var, np.ones_like(var), atol=1e-6))
|
||||||
|
g.weight = g.weight * 2
|
||||||
|
g.bias = g.bias + 3
|
||||||
|
y = g(x)
|
||||||
|
means = y.reshape(2, -1, 2, 4).mean(axis=(1, -1))
|
||||||
|
var = y.reshape(2, -1, 2, 4).var(axis=(1, -1))
|
||||||
|
self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
|
||||||
|
self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))
|
||||||
|
|
||||||
|
def test_conv1d(self):
|
||||||
|
N = 5
|
||||||
|
L = 12
|
||||||
|
ks = 3
|
||||||
|
C_in = 2
|
||||||
|
C_out = 4
|
||||||
|
x = mx.ones((N, L, C_in))
|
||||||
|
c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks)
|
||||||
|
c.weight = mx.ones_like(c.weight)
|
||||||
|
y = c(x)
|
||||||
|
self.assertEqual(y.shape, [N, L - ks + 1, C_out])
|
||||||
|
self.assertTrue(mx.allclose(y, mx.full(y.shape, ks * C_in, mx.float32)))
|
||||||
|
|
||||||
|
c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, stride=2)
|
||||||
|
y = c(x)
|
||||||
|
self.assertEqual(y.shape, [N, (L - ks + 1) // 2, C_out])
|
||||||
|
self.assertTrue("bias" in c.parameters())
|
||||||
|
|
||||||
|
c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, bias=False)
|
||||||
|
self.assertTrue("bias" not in c.parameters())
|
||||||
|
|
||||||
|
def test_conv2d(self):
|
||||||
|
x = mx.ones((4, 8, 8, 3))
|
||||||
|
c = nn.Conv2d(3, 1, 8)
|
||||||
|
y = c(x)
|
||||||
|
self.assertEqual(y.shape, [4, 1, 1, 1])
|
||||||
|
c.weight = mx.ones_like(c.weight) / 8 / 8 / 3
|
||||||
|
y = c(x)
|
||||||
|
self.assertTrue(np.allclose(y[:, 0, 0, 0], x.mean(axis=(1, 2, 3))))
|
||||||
|
|
||||||
|
# 3x3 conv no padding stride 1
|
||||||
|
c = nn.Conv2d(3, 8, 3)
|
||||||
|
y = c(x)
|
||||||
|
self.assertEqual(y.shape, [4, 6, 6, 8])
|
||||||
|
self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)
|
||||||
|
|
||||||
|
# 3x3 conv padding 1 stride 1
|
||||||
|
c = nn.Conv2d(3, 8, 3, padding=1)
|
||||||
|
y = c(x)
|
||||||
|
self.assertEqual(y.shape, [4, 8, 8, 8])
|
||||||
|
self.assertLess(mx.abs(y[:, 1:7, 1:7] - c.weight.sum((1, 2, 3))).max(), 1e-4)
|
||||||
|
self.assertLess(
|
||||||
|
mx.abs(y[:, 0, 0] - c.weight[:, 1:, 1:].sum(axis=(1, 2, 3))).max(),
|
||||||
|
1e-4,
|
||||||
|
)
|
||||||
|
self.assertLess(
|
||||||
|
mx.abs(y[:, 7, 7] - c.weight[:, :-1, :-1].sum(axis=(1, 2, 3))).max(),
|
||||||
|
1e-4,
|
||||||
|
)
|
||||||
|
self.assertLess(
|
||||||
|
mx.abs(y[:, 1:7, 7] - c.weight[:, :, :-1].sum(axis=(1, 2, 3))).max(),
|
||||||
|
1e-4,
|
||||||
|
)
|
||||||
|
self.assertLess(
|
||||||
|
mx.abs(y[:, 7, 1:7] - c.weight[:, :-1, :].sum(axis=(1, 2, 3))).max(),
|
||||||
|
1e-4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3x3 conv no padding stride 2
|
||||||
|
c = nn.Conv2d(3, 8, 3, padding=0, stride=2)
|
||||||
|
y = c(x)
|
||||||
|
self.assertEqual(y.shape, [4, 3, 3, 8])
|
||||||
|
self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)
|
||||||
|
|
||||||
|
def test_sequential(self):
|
||||||
|
x = mx.ones((10, 2))
|
||||||
|
m = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 1))
|
||||||
|
y = m(x)
|
||||||
|
self.assertEqual(y.shape, [10, 1])
|
||||||
|
params = m.parameters()
|
||||||
|
self.assertTrue("layers" in params)
|
||||||
|
self.assertEqual(len(params["layers"]), 3)
|
||||||
|
self.assertTrue("weight" in params["layers"][0])
|
||||||
|
self.assertEqual(len(params["layers"][1]), 0)
|
||||||
|
self.assertTrue("weight" in params["layers"][2])
|
||||||
|
|
||||||
|
m.layers[1] = nn.relu
|
||||||
|
y2 = m(x)
|
||||||
|
self.assertTrue(mx.array_equal(y, y2))
|
||||||
|
|
||||||
|
def test_module_utilities(self):
|
||||||
|
m = nn.Sequential(
|
||||||
|
nn.Sequential(nn.Linear(2, 10), nn.relu),
|
||||||
|
nn.Sequential(nn.Linear(10, 10), nn.ReLU()),
|
||||||
|
nn.Linear(10, 1),
|
||||||
|
mx.sigmoid,
|
||||||
|
)
|
||||||
|
|
||||||
|
children = m.children()
|
||||||
|
self.assertTrue(isinstance(children, dict))
|
||||||
|
self.assertEqual(len(children), 1)
|
||||||
|
self.assertTrue(isinstance(children["layers"], list))
|
||||||
|
self.assertEqual(len(children["layers"]), 4)
|
||||||
|
self.assertEqual(children["layers"][3], {})
|
||||||
|
flat_children = tree_flatten(children, is_leaf=nn.Module.is_module)
|
||||||
|
self.assertEqual(len(flat_children), 3)
|
||||||
|
|
||||||
|
leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module)
|
||||||
|
self.assertEqual(len(leaves), 4)
|
||||||
|
self.assertEqual(leaves[0][0], "layers.0.layers.0")
|
||||||
|
self.assertEqual(leaves[1][0], "layers.1.layers.0")
|
||||||
|
self.assertEqual(leaves[2][0], "layers.1.layers.1")
|
||||||
|
self.assertEqual(leaves[3][0], "layers.2")
|
||||||
|
self.assertTrue(leaves[0][1] is m.layers[0].layers[0])
|
||||||
|
self.assertTrue(leaves[1][1] is m.layers[1].layers[0])
|
||||||
|
self.assertTrue(leaves[2][1] is m.layers[1].layers[1])
|
||||||
|
self.assertTrue(leaves[3][1] is m.layers[2])
|
||||||
|
|
||||||
|
m.eval()
|
||||||
|
|
||||||
|
def assert_not_training(k, m):
|
||||||
|
self.assertFalse(m.training)
|
||||||
|
|
||||||
|
m.apply_to_modules(assert_not_training)
|
||||||
|
|
||||||
|
m.train()
|
||||||
|
|
||||||
|
def assert_training(k, m):
|
||||||
|
self.assertTrue(m.training)
|
||||||
|
|
||||||
|
m.apply_to_modules(assert_training)
|
||||||
|
|
||||||
|
def test_sin_pe(self):
|
||||||
|
m = nn.SinusoidalPositionalEncoding(16, min_freq=0.01)
|
||||||
|
x = mx.arange(10)
|
||||||
|
y = m(x)
|
||||||
|
|
||||||
|
self.assertEqual(y.shape, [10, 16])
|
||||||
|
similarities = y @ y.T
|
||||||
|
self.assertLess(
|
||||||
|
mx.abs(similarities[mx.arange(10), mx.arange(10)] - 1).max(), 1e-5
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_io(self):
|
||||||
|
def make_model():
|
||||||
|
return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
|
||||||
|
|
||||||
|
m = make_model()
|
||||||
|
tdir = tempfile.TemporaryDirectory()
|
||||||
|
file = os.path.join(tdir.name, "model.npz")
|
||||||
|
m.save_weights(file)
|
||||||
|
m_load = make_model()
|
||||||
|
m_load.load_weights(file)
|
||||||
|
tdir.cleanup()
|
||||||
|
|
||||||
|
eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
|
||||||
|
self.assertTrue(all(tree_flatten(eq_tree)))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
29
python/tests/test_optimizers.py
Normal file
29
python/tests/test_optimizers.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.optimizers as opt
|
||||||
|
import mlx.utils
|
||||||
|
|
||||||
|
import mlx_tests
|
||||||
|
|
||||||
|
|
||||||
|
class TestOptimizers(mlx_tests.MLXTestCase):
|
||||||
|
def test_optimizers(self):
|
||||||
|
params = {
|
||||||
|
"first": [mx.zeros((10,)), mx.zeros((1,))],
|
||||||
|
"second": mx.zeros((1,)),
|
||||||
|
}
|
||||||
|
grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params)
|
||||||
|
|
||||||
|
for optim in [opt.SGD(0.1), opt.Adam(0.1)]:
|
||||||
|
update = optim.apply_gradients(grads, params)
|
||||||
|
mx.eval(update)
|
||||||
|
equal_shape = mlx.utils.tree_map(
|
||||||
|
lambda x, y: x.shape == y.shape, params, update
|
||||||
|
)
|
||||||
|
all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape))
|
||||||
|
self.assertTrue(all_equal)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
192
python/tests/test_random.py
Normal file
192
python/tests/test_random.py
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
import mlx_tests
|
||||||
|
|
||||||
|
|
||||||
|
class TestRandom(mlx_tests.MLXTestCase):
|
||||||
|
def test_global_rng(self):
|
||||||
|
mx.random.seed(3)
|
||||||
|
a = mx.random.uniform()
|
||||||
|
b = mx.random.uniform()
|
||||||
|
|
||||||
|
mx.random.seed(3)
|
||||||
|
x = mx.random.uniform()
|
||||||
|
y = mx.random.uniform()
|
||||||
|
|
||||||
|
self.assertEqual(a.item(), x.item())
|
||||||
|
self.assertEqual(y.item(), b.item())
|
||||||
|
|
||||||
|
def test_key(self):
|
||||||
|
k1 = mx.random.key(0)
|
||||||
|
k2 = mx.random.key(0)
|
||||||
|
self.assertTrue(mx.array_equal(k1, k2))
|
||||||
|
|
||||||
|
k2 = mx.random.key(1)
|
||||||
|
self.assertFalse(mx.array_equal(k1, k2))
|
||||||
|
|
||||||
|
def test_key_split(self):
|
||||||
|
key = mx.random.key(0)
|
||||||
|
|
||||||
|
k1, k2 = mx.random.split(key)
|
||||||
|
self.assertFalse(mx.array_equal(k1, k2))
|
||||||
|
|
||||||
|
r1, r2 = mx.random.split(key)
|
||||||
|
self.assertTrue(mx.array_equal(k1, r1))
|
||||||
|
self.assertTrue(mx.array_equal(k2, r2))
|
||||||
|
|
||||||
|
keys = mx.random.split(key, 10)
|
||||||
|
self.assertEqual(keys.shape, [10, 2])
|
||||||
|
|
||||||
|
def test_uniform(self):
|
||||||
|
key = mx.random.key(0)
|
||||||
|
a = mx.random.uniform(key=key)
|
||||||
|
self.assertEqual(a.shape, [])
|
||||||
|
self.assertEqual(a.dtype, mx.float32)
|
||||||
|
|
||||||
|
b = mx.random.uniform(key=key)
|
||||||
|
self.assertEqual(a.item(), b.item())
|
||||||
|
|
||||||
|
a = mx.random.uniform(shape=(2, 3))
|
||||||
|
self.assertEqual(a.shape, [2, 3])
|
||||||
|
|
||||||
|
a = mx.random.uniform(shape=(1000,), low=-1, high=5)
|
||||||
|
self.assertTrue(mx.all((a > -1) < 5).item())
|
||||||
|
|
||||||
|
a = mx.random.uniform(shape=(1000,), low=mx.array(-1), high=5)
|
||||||
|
self.assertTrue(mx.all((a > -1) < 5).item())
|
||||||
|
|
||||||
|
def test_normal(self):
|
||||||
|
key = mx.random.key(0)
|
||||||
|
a = mx.random.normal(key=key)
|
||||||
|
self.assertEqual(a.shape, [])
|
||||||
|
self.assertEqual(a.dtype, mx.float32)
|
||||||
|
|
||||||
|
b = mx.random.normal(key=key)
|
||||||
|
self.assertEqual(a.item(), b.item())
|
||||||
|
|
||||||
|
a = mx.random.normal(shape=(2, 3))
|
||||||
|
self.assertEqual(a.shape, [2, 3])
|
||||||
|
|
||||||
|
## Generate in float16 or bfloat16
|
||||||
|
for t in [mx.float16, mx.bfloat16]:
|
||||||
|
a = mx.random.normal(dtype=t)
|
||||||
|
self.assertEqual(a.dtype, t)
|
||||||
|
|
||||||
|
def test_randint(self):
|
||||||
|
a = mx.random.randint(0, 1, [])
|
||||||
|
self.assertEqual(a.shape, [])
|
||||||
|
self.assertEqual(a.dtype, mx.int32)
|
||||||
|
|
||||||
|
shape = [88]
|
||||||
|
low = mx.array(3)
|
||||||
|
high = mx.array(15)
|
||||||
|
|
||||||
|
key = mx.random.key(0)
|
||||||
|
a = mx.random.randint(low, high, shape, key=key)
|
||||||
|
self.assertEqual(a.shape, shape)
|
||||||
|
self.assertEqual(a.dtype, mx.int32)
|
||||||
|
|
||||||
|
# Check using the same key yields the same value
|
||||||
|
b = mx.random.randint(low, high, shape, key=key)
|
||||||
|
self.assertListEqual(a.tolist(), b.tolist())
|
||||||
|
|
||||||
|
shape = [3, 4]
|
||||||
|
low = mx.reshape(mx.array([0] * 3), [3, 1])
|
||||||
|
high = mx.reshape(mx.array([12, 13, 14, 15]), [1, 4])
|
||||||
|
|
||||||
|
a = mx.random.randint(low, high, shape)
|
||||||
|
self.assertEqual(a.shape, shape)
|
||||||
|
|
||||||
|
a = mx.random.randint(-10, 10, [1000, 1000])
|
||||||
|
self.assertTrue(mx.all(-10 <= a).item() and mx.all(a < 10).item())
|
||||||
|
|
||||||
|
a = mx.random.randint(10, -10, [1000, 1000])
|
||||||
|
self.assertTrue(mx.all(a == 10).item())
|
||||||
|
|
||||||
|
def test_bernoulli(self):
|
||||||
|
a = mx.random.bernoulli()
|
||||||
|
self.assertEqual(a.shape, [])
|
||||||
|
self.assertEqual(a.dtype, mx.bool_)
|
||||||
|
|
||||||
|
a = mx.random.bernoulli(mx.array(0.5), [5])
|
||||||
|
self.assertEqual(a.shape, [5])
|
||||||
|
|
||||||
|
a = mx.random.bernoulli(mx.array([2.0, -2.0]))
|
||||||
|
self.assertEqual(a.tolist(), [True, False])
|
||||||
|
self.assertEqual(a.shape, [2])
|
||||||
|
|
||||||
|
p = mx.array([0.1, 0.2, 0.3])
|
||||||
|
mx.reshape(p, [1, 3])
|
||||||
|
x = mx.random.bernoulli(p, [4, 3])
|
||||||
|
self.assertEqual(x.shape, [4, 3])
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.random.bernoulli(p, [2]) # Bad shape
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.random.bernoulli(0, [2]) # Bad type
|
||||||
|
|
||||||
|
def test_truncated_normal(self):
|
||||||
|
a = mx.random.truncated_normal(-2.0, 2.0)
|
||||||
|
self.assertEqual(a.size, 1)
|
||||||
|
self.assertEqual(a.dtype, mx.float32)
|
||||||
|
|
||||||
|
a = mx.random.truncated_normal(mx.array([]), mx.array([]))
|
||||||
|
self.assertEqual(a.dtype, mx.float32)
|
||||||
|
self.assertEqual(a.size, 0)
|
||||||
|
|
||||||
|
lower = mx.reshape(mx.array([-2.0, 0.0]), [1, 2])
|
||||||
|
upper = mx.reshape(mx.array([0.0, 1.0, 2.0]), [3, 1])
|
||||||
|
a = mx.random.truncated_normal(lower, upper)
|
||||||
|
|
||||||
|
self.assertEqual(a.shape, [3, 2])
|
||||||
|
self.assertTrue(mx.all(lower <= a).item() and mx.all(a <= upper).item())
|
||||||
|
|
||||||
|
a = mx.random.truncated_normal(2.0, -2.0)
|
||||||
|
self.assertTrue(mx.all(a == 2.0).item())
|
||||||
|
|
||||||
|
a = mx.random.truncated_normal(-3.0, 3.0, [542, 399])
|
||||||
|
self.assertEqual(a.shape, [542, 399])
|
||||||
|
|
||||||
|
lower = mx.array([-2.0, -1.0])
|
||||||
|
higher = mx.array([1.0, 2.0, 3.0])
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.random.truncated_normal(lower, higher) # Bad shape
|
||||||
|
|
||||||
|
def test_gumbel(self):
|
||||||
|
samples = mx.random.gumbel(shape=(100, 100))
|
||||||
|
self.assertEqual(samples.shape, [100, 100])
|
||||||
|
self.assertEqual(samples.dtype, mx.float32)
|
||||||
|
mean = 0.5772
|
||||||
|
# Std deviation of the sample mean is small (<0.02),
|
||||||
|
# so this test is pretty conservative
|
||||||
|
self.assertTrue(mx.abs(mx.mean(samples) - mean) < 0.2)
|
||||||
|
|
||||||
|
def test_categorical(self):
|
||||||
|
logits = mx.zeros((10, 20))
|
||||||
|
self.assertEqual(mx.random.categorical(logits, -1).shape, [10])
|
||||||
|
self.assertEqual(mx.random.categorical(logits, 0).shape, [20])
|
||||||
|
self.assertEqual(mx.random.categorical(logits, 1).shape, [10])
|
||||||
|
|
||||||
|
out = mx.random.categorical(logits)
|
||||||
|
self.assertEqual(out.shape, [10])
|
||||||
|
self.assertEqual(out.dtype, mx.uint32)
|
||||||
|
self.assertTrue(mx.max(out).item() < 20)
|
||||||
|
|
||||||
|
out = mx.random.categorical(logits, 0, [5, 20])
|
||||||
|
self.assertEqual(out.shape, [5, 20])
|
||||||
|
self.assertTrue(mx.max(out).item() < 10)
|
||||||
|
|
||||||
|
out = mx.random.categorical(logits, 1, num_samples=7)
|
||||||
|
self.assertEqual(out.shape, [10, 7])
|
||||||
|
out = mx.random.categorical(logits, 0, num_samples=7)
|
||||||
|
self.assertEqual(out.shape, [20, 7])
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.random.categorical(logits, shape=[10, 5], num_samples=5)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
26
python/tests/test_tree.py
Normal file
26
python/tests/test_tree.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.utils
|
||||||
|
|
||||||
|
import mlx_tests
|
||||||
|
|
||||||
|
|
||||||
|
class TestTreeUtils(mlx_tests.MLXTestCase):
|
||||||
|
def test_tree_map(self):
|
||||||
|
tree = {"a": 0, "b": 1, "c": 2}
|
||||||
|
tree = mlx.utils.tree_map(lambda x: x + 1, tree)
|
||||||
|
|
||||||
|
expected_tree = {"a": 1, "b": 2, "c": 3}
|
||||||
|
self.assertEqual(tree, expected_tree)
|
||||||
|
|
||||||
|
def test_tree_flatten(self):
|
||||||
|
tree = [{"a": 1, "b": 2}, "c"]
|
||||||
|
vals = (1, 2, "c")
|
||||||
|
flat_tree = mlx.utils.tree_flatten(tree)
|
||||||
|
self.assertEqual(list(zip(*flat_tree))[1], vals)
|
||||||
|
self.assertEqual(mlx.utils.tree_unflatten(flat_tree), tree)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
167
python/tests/test_vmap.py
Normal file
167
python/tests/test_vmap.py
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
import mlx_tests
|
||||||
|
|
||||||
|
|
||||||
|
class TestVmap(mlx_tests.MLXTestCase):
|
||||||
|
def test_basics(self):
|
||||||
|
# Can't vmap over scalars
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.vmap(mx.exp)(mx.array(1.0))
|
||||||
|
|
||||||
|
# Invalid input
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.vmap(mx.exp)("hello")
|
||||||
|
|
||||||
|
# Invalid axes
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.vmap(mx.exp, in_axes="hello")(mx.array([0, 1]))
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.vmap(mx.exp, in_axes=2)(mx.array([0, 1]))
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.vmap(mx.exp, out_axes="hello")(mx.array([0, 1]))
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.vmap(mx.exp, out_axes=2)(mx.array([0, 1]))
|
||||||
|
|
||||||
|
def test_unary(self):
|
||||||
|
ops = [
|
||||||
|
"abs",
|
||||||
|
"cos",
|
||||||
|
"erf",
|
||||||
|
"erfinv",
|
||||||
|
"exp",
|
||||||
|
"log",
|
||||||
|
"log1p",
|
||||||
|
"log2",
|
||||||
|
"log10",
|
||||||
|
"logical_not",
|
||||||
|
"negative",
|
||||||
|
"reciprocal",
|
||||||
|
"rsqrt",
|
||||||
|
"sigmoid",
|
||||||
|
"sign",
|
||||||
|
"sin",
|
||||||
|
"sqrt",
|
||||||
|
"square",
|
||||||
|
]
|
||||||
|
ops = ["erfinv"]
|
||||||
|
for opname in ops:
|
||||||
|
with self.subTest(op=opname):
|
||||||
|
op = getattr(mx, opname)
|
||||||
|
x = mx.arange(5)
|
||||||
|
y = mx.vmap(op)(x)
|
||||||
|
self.assertTrue(mx.array_equal(y, op(x), equal_nan=True))
|
||||||
|
|
||||||
|
x = mx.arange(8).reshape(2, 4)
|
||||||
|
y = mx.vmap(op)(x)
|
||||||
|
self.assertTrue(mx.array_equal(y, op(x), equal_nan=True))
|
||||||
|
|
||||||
|
y = mx.vmap(op, in_axes=1, out_axes=1)(x)
|
||||||
|
self.assertTrue(mx.array_equal(y, op(x), equal_nan=True))
|
||||||
|
|
||||||
|
def test_binary(self):
|
||||||
|
ops = [
|
||||||
|
"add",
|
||||||
|
"divide",
|
||||||
|
"equal",
|
||||||
|
"greater",
|
||||||
|
"greater_equal",
|
||||||
|
"less",
|
||||||
|
"less_equal",
|
||||||
|
"logaddexp",
|
||||||
|
"maximum",
|
||||||
|
"minimum",
|
||||||
|
"multiply",
|
||||||
|
"power",
|
||||||
|
"subtract",
|
||||||
|
]
|
||||||
|
for opname in ops:
|
||||||
|
with self.subTest(op=opname):
|
||||||
|
op = getattr(mx, opname)
|
||||||
|
x = mx.random.uniform(shape=(5,))
|
||||||
|
y = mx.random.uniform(shape=(5,))
|
||||||
|
out = mx.vmap(op)(x, y)
|
||||||
|
self.assertTrue(mx.array_equal(out, op(x, y)))
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(2, 4))
|
||||||
|
y = mx.random.uniform(shape=(2, 4))
|
||||||
|
out = mx.vmap(op)(x, y)
|
||||||
|
self.assertTrue(mx.array_equal(out, op(x, y)))
|
||||||
|
|
||||||
|
out = mx.vmap(op, in_axes=(0, 0), out_axes=0)(x, y)
|
||||||
|
self.assertTrue(mx.array_equal(out, op(x, y)))
|
||||||
|
|
||||||
|
y = mx.random.uniform(shape=(4, 2))
|
||||||
|
out = mx.vmap(op, in_axes=(0, 1), out_axes=0)(x, y)
|
||||||
|
self.assertTrue(mx.array_equal(out, op(x, y.T)))
|
||||||
|
|
||||||
|
out = mx.vmap(op, in_axes=(0, 1), out_axes=1)(x, y)
|
||||||
|
self.assertTrue(mx.array_equal(out, op(x, y.T).T))
|
||||||
|
|
||||||
|
def test_tree(self):
|
||||||
|
def my_fun(tree):
|
||||||
|
return (tree["a"] + tree["b"][0]) * tree["b"][1]
|
||||||
|
|
||||||
|
tree = {
|
||||||
|
"a": mx.random.uniform(shape=(2, 4)),
|
||||||
|
"b": (
|
||||||
|
mx.random.uniform(shape=(2, 4)),
|
||||||
|
mx.random.uniform(shape=(2, 4)),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
out = mx.vmap(my_fun)(tree)
|
||||||
|
expected = my_fun(tree)
|
||||||
|
self.assertTrue(mx.array_equal(out, my_fun(tree)))
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.vmap(my_fun, in_axes={"a": 0, "b": 0}, out_axes=0)(tree)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.vmap(my_fun, in_axes={"a": 0, "b": ((0, 0), 0)}, out_axes=0)(tree)
|
||||||
|
|
||||||
|
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": 0},), out_axes=0)(tree)
|
||||||
|
self.assertTrue(mx.array_equal(out, my_fun(tree)))
|
||||||
|
|
||||||
|
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (0, 0)},), out_axes=0)(tree)
|
||||||
|
self.assertTrue(mx.array_equal(out, my_fun(tree)))
|
||||||
|
|
||||||
|
tree = {
|
||||||
|
"a": mx.random.uniform(shape=(2, 4)),
|
||||||
|
"b": (
|
||||||
|
mx.random.uniform(shape=(4, 2)),
|
||||||
|
mx.random.uniform(shape=(4, 2)),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (1, 1)},), out_axes=0)(tree)
|
||||||
|
expected = (tree["a"] + tree["b"][0].T) * tree["b"][1].T
|
||||||
|
self.assertTrue(mx.array_equal(out, expected))
|
||||||
|
|
||||||
|
def my_fun(x, y):
|
||||||
|
return {"a": x + y, "b": x * y}
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(2, 4))
|
||||||
|
y = mx.random.uniform(shape=(2, 4))
|
||||||
|
out = mx.vmap(my_fun, in_axes=0, out_axes=0)(x, y)
|
||||||
|
expected = my_fun(x, y)
|
||||||
|
self.assertTrue(mx.array_equal(out["a"], expected["a"]))
|
||||||
|
self.assertTrue(mx.array_equal(out["b"], expected["b"]))
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.vmap(my_fun, in_axes=0, out_axes=(0, 1))(x, y)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.vmap(my_fun, in_axes=0, out_axes={"a": 0, "c": 1})(x, y)
|
||||||
|
|
||||||
|
out = mx.vmap(my_fun, in_axes=0, out_axes={"a": 1, "b": 0})(x, y)
|
||||||
|
expected = my_fun(x, y)
|
||||||
|
self.assertTrue(mx.array_equal(out["a"].T, expected["a"]))
|
||||||
|
self.assertTrue(mx.array_equal(out["b"], expected["b"]))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
127
setup.py
Normal file
127
setup.py
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import sysconfig
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from setuptools import Extension, setup, find_namespace_packages
|
||||||
|
from setuptools.command.build_ext import build_ext
|
||||||
|
|
||||||
|
|
||||||
|
# A CMakeExtension needs a sourcedir instead of a file list.
|
||||||
|
# The name must be the _single_ output extension from the CMake build.
|
||||||
|
# If you need multiple extensions, see scikit-build.
|
||||||
|
class CMakeExtension(Extension):
|
||||||
|
def __init__(self, name: str, sourcedir: str = "") -> None:
|
||||||
|
super().__init__(name, sources=[])
|
||||||
|
self.sourcedir = os.fspath(Path(sourcedir).resolve())
|
||||||
|
|
||||||
|
|
||||||
|
class CMakeBuild(build_ext):
|
||||||
|
def build_extension(self, ext: CMakeExtension) -> None:
|
||||||
|
# Must be in this form due to bug in .resolve() only fixed in Python 3.10+
|
||||||
|
ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name) # type: ignore[no-untyped-call]
|
||||||
|
extdir = ext_fullpath.parent.resolve()
|
||||||
|
|
||||||
|
debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug
|
||||||
|
cfg = "Debug" if debug else "Release"
|
||||||
|
|
||||||
|
# CMake lets you override the generator - we need to check this.
|
||||||
|
# Can be set with Conda-Build, for example.
|
||||||
|
cmake_generator = os.environ.get("CMAKE_GENERATOR", "")
|
||||||
|
|
||||||
|
# Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON
|
||||||
|
# EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code
|
||||||
|
# from Python.
|
||||||
|
cmake_args = [
|
||||||
|
f"-DCMAKE_INSTALL_PREFIX={extdir}{os.sep}",
|
||||||
|
f"-DCMAKE_BUILD_TYPE={cfg}",
|
||||||
|
"-DBUILD_SHARED_LIBS=ON",
|
||||||
|
"-DMLX_BUILD_PYTHON_BINDINGS=ON",
|
||||||
|
"-DMLX_BUILD_TESTS=OFF",
|
||||||
|
"-DMLX_BUILD_BENCHMARKS=OFF",
|
||||||
|
"-DMLX_BUILD_EXAMPLES=OFF",
|
||||||
|
f"-DMLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY={extdir}{os.sep}",
|
||||||
|
]
|
||||||
|
build_args = []
|
||||||
|
# Adding CMake arguments set as environment variable
|
||||||
|
# (needed e.g. to build for ARM OSx on conda-forge)
|
||||||
|
if "CMAKE_ARGS" in os.environ:
|
||||||
|
cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item]
|
||||||
|
|
||||||
|
# Pass version to C++
|
||||||
|
cmake_args += [f"-DMLX_VERSION={self.distribution.get_version()}"] # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
if sys.platform.startswith("darwin"):
|
||||||
|
# Cross-compile support for macOS - respect ARCHFLAGS if set
|
||||||
|
archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", ""))
|
||||||
|
if archs:
|
||||||
|
cmake_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))]
|
||||||
|
|
||||||
|
# Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level
|
||||||
|
# across all generators.
|
||||||
|
if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
|
||||||
|
# self.parallel is a Python 3 only way to set parallel jobs by hand
|
||||||
|
# using -j in the build_ext call, not supported by pip or PyPA-build.
|
||||||
|
if hasattr(self, "parallel") and self.parallel:
|
||||||
|
# CMake 3.12+ only.
|
||||||
|
build_args += [f"-j{self.parallel}"]
|
||||||
|
|
||||||
|
build_temp = Path(self.build_temp) / ext.name
|
||||||
|
if not build_temp.exists():
|
||||||
|
build_temp.mkdir(parents=True)
|
||||||
|
|
||||||
|
subprocess.run(
|
||||||
|
["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
|
||||||
|
)
|
||||||
|
subprocess.run(
|
||||||
|
["cmake", "--build", ".", "--target", "install", *build_args],
|
||||||
|
cwd=build_temp,
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make sure to copy mlx.metallib for inplace builds
|
||||||
|
def run(self):
|
||||||
|
super().run()
|
||||||
|
|
||||||
|
# Based on https://github.com/pypa/setuptools/blob/main/setuptools/command/build_ext.py#L102
|
||||||
|
if self.inplace:
|
||||||
|
for ext in self.extensions:
|
||||||
|
if ext.name == "mlx.core":
|
||||||
|
# Resolve inplace package dir
|
||||||
|
build_py = self.get_finalized_command("build_py")
|
||||||
|
inplace_file, regular_file = self._get_inplace_equivalent(
|
||||||
|
build_py, ext
|
||||||
|
)
|
||||||
|
|
||||||
|
inplace_dir = str(Path(inplace_file).parent.resolve())
|
||||||
|
regular_dir = str(Path(regular_file).parent.resolve())
|
||||||
|
|
||||||
|
self.copy_tree(regular_dir, inplace_dir)
|
||||||
|
|
||||||
|
|
||||||
|
# The information here can also be placed in setup.cfg - better separation of
|
||||||
|
# logic and declaration, and simpler if you include description/version in a file.
|
||||||
|
if __name__ == "__main__":
|
||||||
|
packages = find_namespace_packages(
|
||||||
|
where="python", exclude=["src", "tests", "tests.*"]
|
||||||
|
)
|
||||||
|
package_dir = {"": "python"}
|
||||||
|
package_data = {"mlx": ["lib/*", "include/*", "share/*"]}
|
||||||
|
setup(
|
||||||
|
name="mlx",
|
||||||
|
version="0.0.2",
|
||||||
|
author="MLX Contributors",
|
||||||
|
author_email="mlx@group.apple.com",
|
||||||
|
description="A framework for machine learning on Apple Silicon.",
|
||||||
|
long_description="",
|
||||||
|
packages=packages,
|
||||||
|
package_dir=package_dir,
|
||||||
|
package_data=package_data,
|
||||||
|
include_package_data=True,
|
||||||
|
ext_modules=[CMakeExtension("mlx.core")],
|
||||||
|
cmdclass={"build_ext": CMakeBuild},
|
||||||
|
zip_safe=False,
|
||||||
|
python_requires=">=3.7",
|
||||||
|
)
|
41
tests/allocator_tests.cpp
Normal file
41
tests/allocator_tests.cpp
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
|
#include "doctest/doctest.h"
|
||||||
|
|
||||||
|
#include "mlx/allocator.h"
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
TEST_CASE("test simple allocations") {
|
||||||
|
{
|
||||||
|
auto buffer = allocator::malloc(sizeof(float));
|
||||||
|
auto fptr = static_cast<float*>(buffer.raw_ptr());
|
||||||
|
*fptr = 0.5f;
|
||||||
|
CHECK_EQ(*fptr, 0.5f);
|
||||||
|
allocator::free(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto buffer = allocator::malloc(128 * sizeof(int));
|
||||||
|
int* ptr = static_cast<int*>(buffer.raw_ptr());
|
||||||
|
for (int i = 0; i < 128; ++i) {
|
||||||
|
ptr[i] = i;
|
||||||
|
}
|
||||||
|
allocator::free(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto buffer = allocator::malloc(0);
|
||||||
|
allocator::free(buffer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test large allocations") {
|
||||||
|
size_t size = 1 << 30;
|
||||||
|
for (int i = 0; i < 100; ++i) {
|
||||||
|
auto buffer = allocator::malloc(size);
|
||||||
|
allocator::free(buffer);
|
||||||
|
}
|
||||||
|
// Shouldn't be able to allocate an exabyte anytime soon.
|
||||||
|
CHECK_THROWS_AS(allocator::malloc(1ull << 60), std::runtime_error);
|
||||||
|
}
|
589
tests/array_tests.cpp
Normal file
589
tests/array_tests.cpp
Normal file
@ -0,0 +1,589 @@
|
|||||||
|
#include <climits>
|
||||||
|
|
||||||
|
#include "doctest/doctest.h"
|
||||||
|
|
||||||
|
#include "mlx/mlx.h"
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
TEST_CASE("test array basics") {
|
||||||
|
// Scalar
|
||||||
|
array x(1.0);
|
||||||
|
CHECK_EQ(x.size(), 1);
|
||||||
|
CHECK_EQ(x.ndim(), 0);
|
||||||
|
CHECK_EQ(x.shape(), std::vector<int>{});
|
||||||
|
CHECK_THROWS_AS(x.shape(0), std::out_of_range);
|
||||||
|
CHECK_THROWS_AS(x.shape(-1), std::out_of_range);
|
||||||
|
CHECK_EQ(x.strides(), std::vector<size_t>{});
|
||||||
|
CHECK_EQ(x.itemsize(), sizeof(float));
|
||||||
|
CHECK_EQ(x.nbytes(), sizeof(float));
|
||||||
|
CHECK_EQ(x.dtype(), float32);
|
||||||
|
CHECK_EQ(x.item<float>(), 1.0);
|
||||||
|
|
||||||
|
// Scalar with specified type
|
||||||
|
x = array(1, float32);
|
||||||
|
CHECK_EQ(x.dtype(), float32);
|
||||||
|
CHECK_EQ(x.item<float>(), 1.0);
|
||||||
|
|
||||||
|
// Scalar with specified type
|
||||||
|
x = array(1, bool_);
|
||||||
|
CHECK_EQ(x.dtype(), bool_);
|
||||||
|
CHECK_EQ(x.itemsize(), sizeof(bool));
|
||||||
|
CHECK_EQ(x.nbytes(), sizeof(bool));
|
||||||
|
CHECK_EQ(x.item<bool>(), true);
|
||||||
|
|
||||||
|
// Check shaped arrays
|
||||||
|
x = array({1.0});
|
||||||
|
CHECK_EQ(x.dtype(), float32);
|
||||||
|
CHECK_EQ(x.size(), 1);
|
||||||
|
CHECK_EQ(x.ndim(), 1);
|
||||||
|
CHECK_EQ(x.shape(), std::vector<int>{1});
|
||||||
|
CHECK_EQ(x.shape(0), 1);
|
||||||
|
CHECK_EQ(x.shape(-1), 1);
|
||||||
|
CHECK_THROWS_AS(x.shape(1), std::out_of_range);
|
||||||
|
CHECK_THROWS_AS(x.shape(-2), std::out_of_range);
|
||||||
|
CHECK_EQ(x.strides(), std::vector<size_t>{1});
|
||||||
|
CHECK_EQ(x.item<float>(), 1.0);
|
||||||
|
|
||||||
|
// Check empty array
|
||||||
|
x = array({});
|
||||||
|
CHECK_EQ(x.size(), 0);
|
||||||
|
CHECK_EQ(x.dtype(), float32);
|
||||||
|
CHECK_EQ(x.itemsize(), sizeof(float));
|
||||||
|
CHECK_EQ(x.nbytes(), 0);
|
||||||
|
CHECK_THROWS_AS(x.item<float>(), std::invalid_argument);
|
||||||
|
|
||||||
|
x = array({1.0, 1.0});
|
||||||
|
CHECK_EQ(x.size(), 2);
|
||||||
|
CHECK_EQ(x.shape(), std::vector<int>{2});
|
||||||
|
CHECK_EQ(x.itemsize(), sizeof(float));
|
||||||
|
CHECK_EQ(x.nbytes(), x.itemsize() * x.size());
|
||||||
|
|
||||||
|
// Accessing item in non-scalar array throws
|
||||||
|
CHECK_THROWS_AS(x.item<float>(), std::invalid_argument);
|
||||||
|
|
||||||
|
x = array({1.0, 1.0, 1.0}, {1, 3});
|
||||||
|
CHECK(x.size() == 3);
|
||||||
|
CHECK(x.shape() == std::vector<int>{1, 3});
|
||||||
|
CHECK(x.strides() == std::vector<size_t>{3, 1});
|
||||||
|
|
||||||
|
// Test wrong size/shapes throw:
|
||||||
|
CHECK_THROWS_AS(array({1.0, 1.0, 1.0}, {4}), std::invalid_argument);
|
||||||
|
CHECK_THROWS_AS(array({1.0, 1.0, 1.0}, {1, 4}), std::invalid_argument);
|
||||||
|
CHECK_THROWS_AS(array({1.0, 1.0, 1.0}, {1, 2}), std::invalid_argument);
|
||||||
|
|
||||||
|
// Test array ids work as expected
|
||||||
|
x = array(1.0);
|
||||||
|
auto y = x;
|
||||||
|
CHECK_EQ(y.id(), x.id());
|
||||||
|
array z(2.0);
|
||||||
|
CHECK_NE(z.id(), x.id());
|
||||||
|
z = x;
|
||||||
|
CHECK_EQ(z.id(), x.id());
|
||||||
|
|
||||||
|
// Array creation from pointer
|
||||||
|
float data[] = {0.0, 1.0, 2.0, 3.0};
|
||||||
|
x = array(data, {4});
|
||||||
|
CHECK_EQ(x.dtype(), float32);
|
||||||
|
CHECK(array_equal(x, array({0.0, 1.0, 2.0, 3.0})).item<bool>());
|
||||||
|
|
||||||
|
// Array creation from vectors
|
||||||
|
{
|
||||||
|
std::vector<int> data = {0, 1, 2, 3};
|
||||||
|
x = array(data.begin(), {4});
|
||||||
|
CHECK_EQ(x.dtype(), int32);
|
||||||
|
CHECK(array_equal(x, array({0, 1, 2, 3})).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
std::vector<bool> data = {false, true, false, true};
|
||||||
|
x = array(data.begin(), {4});
|
||||||
|
CHECK_EQ(x.dtype(), bool_);
|
||||||
|
CHECK(array_equal(x, array({false, true, false, true})).item<bool>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test array types") {
|
||||||
|
#define basic_dtype_test(T, mlx_type) \
|
||||||
|
T val = 42; \
|
||||||
|
array x(val); \
|
||||||
|
CHECK_EQ(x.dtype(), mlx_type); \
|
||||||
|
CHECK_EQ(x.item<T>(), val); \
|
||||||
|
x = array({val, val}); \
|
||||||
|
CHECK_EQ(x.dtype(), mlx_type);
|
||||||
|
|
||||||
|
// bool_
|
||||||
|
{
|
||||||
|
array x(true);
|
||||||
|
CHECK_EQ(x.dtype(), bool_);
|
||||||
|
CHECK_EQ(x.item<bool>(), true);
|
||||||
|
|
||||||
|
x = array({true, false});
|
||||||
|
CHECK_EQ(x.dtype(), bool_);
|
||||||
|
|
||||||
|
x = array({true, false}, float32);
|
||||||
|
CHECK_EQ(x.dtype(), float32);
|
||||||
|
CHECK(array_equal(x, array({1.0f, 0.0f})).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
// uint8
|
||||||
|
{ basic_dtype_test(uint8_t, uint8); }
|
||||||
|
|
||||||
|
// uint16
|
||||||
|
{ basic_dtype_test(uint16_t, uint16); }
|
||||||
|
|
||||||
|
// uint32
|
||||||
|
{ basic_dtype_test(uint32_t, uint32); }
|
||||||
|
|
||||||
|
// uint64
|
||||||
|
{ basic_dtype_test(uint64_t, uint64); }
|
||||||
|
|
||||||
|
// int8
|
||||||
|
{ basic_dtype_test(int8_t, int8); }
|
||||||
|
|
||||||
|
// int16
|
||||||
|
{ basic_dtype_test(int16_t, int16); }
|
||||||
|
|
||||||
|
// int32
|
||||||
|
{ basic_dtype_test(int32_t, int32); }
|
||||||
|
|
||||||
|
// int64
|
||||||
|
{ basic_dtype_test(int64_t, int64); }
|
||||||
|
|
||||||
|
// float16
|
||||||
|
{ basic_dtype_test(float16_t, float16); }
|
||||||
|
|
||||||
|
// float32
|
||||||
|
{ basic_dtype_test(float, float32); }
|
||||||
|
|
||||||
|
// bfloat16
|
||||||
|
{ basic_dtype_test(bfloat16_t, bfloat16); }
|
||||||
|
|
||||||
|
// uint32
|
||||||
|
{
|
||||||
|
uint32_t val = UINT_MAX;
|
||||||
|
array x(val);
|
||||||
|
CHECK_EQ(x.dtype(), uint32);
|
||||||
|
CHECK_EQ(x.item<uint32_t>(), val);
|
||||||
|
|
||||||
|
x = array({1u, 2u});
|
||||||
|
CHECK_EQ(x.dtype(), uint32);
|
||||||
|
}
|
||||||
|
|
||||||
|
// int32
|
||||||
|
{
|
||||||
|
array x(-1);
|
||||||
|
CHECK_EQ(x.dtype(), int32);
|
||||||
|
CHECK_EQ(x.item<int>(), -1);
|
||||||
|
|
||||||
|
x = array({-1, 2});
|
||||||
|
CHECK_EQ(x.dtype(), int32);
|
||||||
|
|
||||||
|
std::vector<int> data{0, 1, 2};
|
||||||
|
x = array(data.data(), {static_cast<int>(data.size())}, bool_);
|
||||||
|
CHECK_EQ(x.dtype(), bool_);
|
||||||
|
CHECK(array_equal(x, array({false, true, true})).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
// int64
|
||||||
|
{
|
||||||
|
int64_t val = static_cast<int64_t>(INT_MIN) - 1;
|
||||||
|
array x(val);
|
||||||
|
CHECK_EQ(x.dtype(), int64);
|
||||||
|
CHECK_EQ(x.item<int64_t>(), val);
|
||||||
|
|
||||||
|
x = array({val, val});
|
||||||
|
CHECK_EQ(x.dtype(), int64);
|
||||||
|
}
|
||||||
|
|
||||||
|
// float32
|
||||||
|
{
|
||||||
|
array x(3.14f);
|
||||||
|
CHECK_EQ(x.dtype(), float32);
|
||||||
|
CHECK_EQ(x.item<float>(), 3.14f);
|
||||||
|
|
||||||
|
x = array(1.25);
|
||||||
|
CHECK_EQ(x.dtype(), float32);
|
||||||
|
CHECK_EQ(x.item<float>(), 1.25f);
|
||||||
|
|
||||||
|
x = array({1.0f, 2.0f});
|
||||||
|
CHECK_EQ(x.dtype(), float32);
|
||||||
|
|
||||||
|
x = array({1.0, 2.0});
|
||||||
|
CHECK_EQ(x.dtype(), float32);
|
||||||
|
|
||||||
|
std::vector<double> data{1.0, 2.0, 4.0};
|
||||||
|
x = array(data.data(), {static_cast<int>(data.size())});
|
||||||
|
CHECK_EQ(x.dtype(), float32);
|
||||||
|
CHECK(array_equal(x, array({1.0f, 2.0f, 4.0f})).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
// complex64
|
||||||
|
{
|
||||||
|
complex64_t v = {1.0f, 1.0f};
|
||||||
|
array x(v);
|
||||||
|
CHECK_EQ(x.dtype(), complex64);
|
||||||
|
CHECK_EQ(x.item<complex64_t>(), v);
|
||||||
|
|
||||||
|
array y(std::complex<float>{1.0f, 1.0f});
|
||||||
|
CHECK_EQ(x.dtype(), complex64);
|
||||||
|
CHECK_EQ(x.item<complex64_t>(), v);
|
||||||
|
}
|
||||||
|
|
||||||
|
#undef basic_dtype_test
|
||||||
|
|
||||||
|
#define basic_dtype_str_test(s, dtype) \
|
||||||
|
CHECK_EQ(s, dtype_to_array_protocol(dtype)); \
|
||||||
|
CHECK_EQ(dtype_from_array_protocol(s), dtype);
|
||||||
|
|
||||||
|
// To and from str
|
||||||
|
{
|
||||||
|
basic_dtype_str_test("|b1", bool_);
|
||||||
|
basic_dtype_str_test("|u1", uint8);
|
||||||
|
basic_dtype_str_test("<u2", uint16);
|
||||||
|
basic_dtype_str_test("<u4", uint32);
|
||||||
|
basic_dtype_str_test("<u8", uint64);
|
||||||
|
basic_dtype_str_test("|i1", int8);
|
||||||
|
basic_dtype_str_test("<i2", int16);
|
||||||
|
basic_dtype_str_test("<i4", int32);
|
||||||
|
basic_dtype_str_test("<i8", int64);
|
||||||
|
basic_dtype_str_test("<f2", float16);
|
||||||
|
basic_dtype_str_test("<f4", float32);
|
||||||
|
basic_dtype_str_test("<V2", bfloat16);
|
||||||
|
basic_dtype_str_test("<c8", complex64);
|
||||||
|
}
|
||||||
|
|
||||||
|
#undef basic_dtype_str_test
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test array metadata") {
|
||||||
|
array x(1.0f);
|
||||||
|
CHECK_EQ(x.data_size(), 1);
|
||||||
|
CHECK_EQ(x.flags().contiguous, true);
|
||||||
|
CHECK_EQ(x.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(x.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
x = array({1.0f}, {1, 1, 1});
|
||||||
|
CHECK_EQ(x.data_size(), 1);
|
||||||
|
CHECK_EQ(x.flags().contiguous, true);
|
||||||
|
CHECK_EQ(x.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(x.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
x = array({1.0f, 1.0f}, {1, 2});
|
||||||
|
CHECK_EQ(x.data_size(), 2);
|
||||||
|
CHECK_EQ(x.flags().contiguous, true);
|
||||||
|
CHECK_EQ(x.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(x.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
x = zeros({1, 1, 4});
|
||||||
|
eval(x);
|
||||||
|
CHECK_EQ(x.data_size(), 4);
|
||||||
|
CHECK_EQ(x.flags().contiguous, true);
|
||||||
|
CHECK_EQ(x.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(x.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
x = zeros({2, 4});
|
||||||
|
eval(x);
|
||||||
|
CHECK_EQ(x.data_size(), 8);
|
||||||
|
CHECK_EQ(x.flags().contiguous, true);
|
||||||
|
CHECK_EQ(x.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(x.flags().col_contiguous, false);
|
||||||
|
|
||||||
|
x = array(1.0f);
|
||||||
|
auto y = broadcast_to(x, {1, 1, 1});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 1);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
y = broadcast_to(x, {2, 8, 10});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 1);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, false);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, false);
|
||||||
|
|
||||||
|
y = broadcast_to(x, {1, 0});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 0);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
y = broadcast_to(zeros({4, 2, 1}), {4, 2, 0});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 0);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
x = array(1.0f);
|
||||||
|
y = transpose(x);
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 1);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
x = ones({1, 1, 1});
|
||||||
|
y = transpose(x);
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 1);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
x = ones({1, 1, 1});
|
||||||
|
y = transpose(x, {0, 1, 2});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 1);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
x = ones({1, 1, 1});
|
||||||
|
y = transpose(x, {1, 2, 0});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 1);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
x = ones({4, 1});
|
||||||
|
y = transpose(x);
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 4);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
x = ones({2, 3, 4});
|
||||||
|
y = transpose(x);
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 24);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, false);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
y = transpose(x, {0, 2, 1});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 24);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, false);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, false);
|
||||||
|
|
||||||
|
y = transpose(transpose(x, {0, 2, 1}), {0, 2, 1});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 24);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, false);
|
||||||
|
|
||||||
|
x = array(1.0f);
|
||||||
|
y = reshape(x, {1, 1, 1});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 1);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
x = ones({2, 4});
|
||||||
|
y = reshape(x, {8});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 8);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
y = reshape(x, {8, 1, 1});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 8);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
y = reshape(x, {1, 8, 1});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 8);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
x = ones({12});
|
||||||
|
y = reshape(x, {2, 3, 2});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 12);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, false);
|
||||||
|
|
||||||
|
x = array(1.0f);
|
||||||
|
y = slice(x, {}, {});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 1);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
x = array({1.0f});
|
||||||
|
y = slice(x, {-10}, {10}, {10});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 1);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
x = array({1.0f, 2.0f, 3.0f}, {1, 3});
|
||||||
|
y = slice(x, {0, 0}, {1, 3}, {1, 1});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 3);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
x = array({1.0f, 2.0f, 3.0f}, {1, 3});
|
||||||
|
y = slice(x, {0, 0}, {1, 3}, {1, 1});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 3);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
x = array({1.0f, 2.0f, 3.0f}, {1, 3});
|
||||||
|
y = slice(x, {0, 0}, {0, 3}, {1, 1});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 0);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
x = array({1.0f, 2.0f, 3.0f}, {1, 3});
|
||||||
|
y = slice(x, {0, 0}, {1, 2}, {1, 1});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 2);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
x = array({1.0f, 2.0f, 3.0f}, {1, 3});
|
||||||
|
y = slice(x, {0, 0}, {1, 2}, {2, 3});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.shape(), std::vector<int>{1, 1});
|
||||||
|
CHECK_EQ(y.data_size(), 1);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
x = array({0.0f, 1.0f, 2.0f, 3.0f}, {1, 4});
|
||||||
|
y = slice(x, {0, 0}, {1, 4}, {1, 2});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.shape(), std::vector<int>{1, 2});
|
||||||
|
CHECK_EQ(y.flags().contiguous, false);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, false);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, false);
|
||||||
|
|
||||||
|
x = broadcast_to(array(1.0f), {4, 10});
|
||||||
|
y = slice(x, {0, 0}, {4, 10}, {2, 2});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.shape(), std::vector<int>{2, 5});
|
||||||
|
CHECK_EQ(y.data_size(), 1);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, false);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, false);
|
||||||
|
|
||||||
|
x = broadcast_to(array({1.0f, 2.0f}), {4, 2});
|
||||||
|
y = slice(x, {0, 0}, {1, 2}, {1, 1});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 2);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
y = slice(x, {1, 0}, {2, 2}, {1, 1});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 2);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
x = array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2});
|
||||||
|
y = slice(x, {0, 0}, {2, 2}, {1, 1});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 4);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, false);
|
||||||
|
|
||||||
|
y = slice(transpose(x), {0, 0}, {2, 2}, {1, 1});
|
||||||
|
eval(y);
|
||||||
|
CHECK_EQ(y.data_size(), 4);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, false);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, true);
|
||||||
|
|
||||||
|
x = ones({2, 4});
|
||||||
|
auto out = split(x, 2);
|
||||||
|
eval(out);
|
||||||
|
for (auto y : out) {
|
||||||
|
CHECK_EQ(y.data_size(), 4);
|
||||||
|
CHECK_EQ(y.flags().contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, true);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, true);
|
||||||
|
}
|
||||||
|
out = split(x, 4, 1);
|
||||||
|
eval(out);
|
||||||
|
for (auto y : out) {
|
||||||
|
CHECK_EQ(y.flags().contiguous, false);
|
||||||
|
CHECK_EQ(y.flags().row_contiguous, false);
|
||||||
|
CHECK_EQ(y.flags().col_contiguous, false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test array iteration") {
|
||||||
|
// Dim 0 arrays
|
||||||
|
auto arr = array(1);
|
||||||
|
CHECK_THROWS(arr.begin());
|
||||||
|
|
||||||
|
// Iterated arrays are read only
|
||||||
|
CHECK(std::is_const_v<decltype(*arr.begin())>);
|
||||||
|
|
||||||
|
arr = array({1, 2, 3, 4, 5});
|
||||||
|
int i = 0;
|
||||||
|
for (auto a : arr) {
|
||||||
|
i++;
|
||||||
|
CHECK_EQ(a.item<int>(), i);
|
||||||
|
}
|
||||||
|
CHECK_EQ(i, 5);
|
||||||
|
|
||||||
|
arr = array({1, 2, 3, 4}, {2, 2});
|
||||||
|
CHECK(array_equal(*arr.begin(), array({1, 2})).item<bool>());
|
||||||
|
CHECK(array_equal(*(arr.begin() + 1), array({3, 4})).item<bool>());
|
||||||
|
CHECK_EQ(arr.begin() + 2, arr.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test array shared buffer") {
|
||||||
|
std::vector<int> shape = {2, 2};
|
||||||
|
int n_elem = shape[0] * shape[1];
|
||||||
|
|
||||||
|
allocator::Buffer buf_b = allocator::malloc(n_elem * sizeof(float));
|
||||||
|
void* buf_b_ptr = buf_b.raw_ptr();
|
||||||
|
float* float_buf_b = (float*)buf_b_ptr;
|
||||||
|
|
||||||
|
for (int i = 0; i < n_elem; i++) {
|
||||||
|
float_buf_b[i] = 2.;
|
||||||
|
}
|
||||||
|
|
||||||
|
CHECK_EQ(float_buf_b[0], ((float*)buf_b_ptr)[0]);
|
||||||
|
|
||||||
|
auto deleter = [float_buf_b](allocator::Buffer buf) {
|
||||||
|
CHECK_EQ(float_buf_b, (float*)buf.raw_ptr());
|
||||||
|
CHECK_EQ(float_buf_b[0], ((float*)buf.raw_ptr())[0]);
|
||||||
|
allocator::free(buf);
|
||||||
|
};
|
||||||
|
|
||||||
|
array a = ones(shape, float32);
|
||||||
|
array b = array(buf_b, shape, float32, deleter);
|
||||||
|
|
||||||
|
eval(a + b);
|
||||||
|
}
|
1192
tests/autograd_tests.cpp
Normal file
1192
tests/autograd_tests.cpp
Normal file
File diff suppressed because it is too large
Load Diff
33
tests/device_tests.cpp
Normal file
33
tests/device_tests.cpp
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
#include "doctest/doctest.h"
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
|
|
||||||
|
#include "mlx/mlx.h"
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
TEST_CASE("test device placement") {
|
||||||
|
auto device = default_device();
|
||||||
|
Device d = metal::is_available() ? Device::gpu : Device::cpu;
|
||||||
|
if (std::getenv("DEVICE") == nullptr) {
|
||||||
|
CHECK_EQ(device, d);
|
||||||
|
}
|
||||||
|
|
||||||
|
array x(1.0f);
|
||||||
|
array y(1.0f);
|
||||||
|
auto z = add(x, y, default_device());
|
||||||
|
if (metal::is_available()) {
|
||||||
|
z = add(x, y, Device::gpu);
|
||||||
|
z = add(x, y, Device(Device::gpu, 0));
|
||||||
|
} else {
|
||||||
|
CHECK_THROWS_AS(set_default_device(Device::gpu), std::invalid_argument);
|
||||||
|
CHECK_THROWS_AS(add(x, y, Device::gpu), std::invalid_argument);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the default device to the CPU
|
||||||
|
set_default_device(Device::cpu);
|
||||||
|
CHECK_EQ(default_device(), Device::cpu);
|
||||||
|
|
||||||
|
// Revert
|
||||||
|
set_default_device(device);
|
||||||
|
}
|
97
tests/eval_tests.cpp
Normal file
97
tests/eval_tests.cpp
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
#include "doctest/doctest.h"
|
||||||
|
|
||||||
|
#include "mlx/mlx.h"
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
TEST_CASE("test eval") {
|
||||||
|
{
|
||||||
|
array x(1.0);
|
||||||
|
array y(1);
|
||||||
|
array z(true);
|
||||||
|
eval({x, y, z});
|
||||||
|
CHECK_EQ(x.item<float>(), 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
array x(1.0);
|
||||||
|
array y = ones({2, 2});
|
||||||
|
array z(true);
|
||||||
|
eval({x, y, z});
|
||||||
|
CHECK(array_equal(y, array({1.0, 1.0, 1.0, 1.0}, {2, 2})).item<bool>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test eval multiple") {
|
||||||
|
auto x = ones({10, 10});
|
||||||
|
auto y = ones({10, 10});
|
||||||
|
eval({x, y});
|
||||||
|
CHECK(array_equal(x, y).item<bool>());
|
||||||
|
|
||||||
|
auto a = x + y;
|
||||||
|
auto b = x - y;
|
||||||
|
eval({a, b});
|
||||||
|
CHECK(array_equal(a, full({10, 10}, 2.0f)).item<bool>());
|
||||||
|
CHECK(array_equal(b, full({10, 10}, 0.0f)).item<bool>());
|
||||||
|
|
||||||
|
x = ones({10, 10});
|
||||||
|
y = ones({10, 10});
|
||||||
|
eval(x, y);
|
||||||
|
CHECK(array_equal(x, y).item<bool>());
|
||||||
|
|
||||||
|
a = x + y;
|
||||||
|
b = x - y;
|
||||||
|
eval(a, b);
|
||||||
|
CHECK(array_equal(a, full({10, 10}, 2.0f)).item<bool>());
|
||||||
|
CHECK(array_equal(b, full({10, 10}, 0.0f)).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test eval with tracer") {
|
||||||
|
auto x = array(1);
|
||||||
|
x.set_tracer(true);
|
||||||
|
|
||||||
|
// Ok, x is not a node
|
||||||
|
eval(x);
|
||||||
|
|
||||||
|
x = ones({2, 3});
|
||||||
|
x.set_tracer(true);
|
||||||
|
CHECK_THROWS(eval(x));
|
||||||
|
|
||||||
|
// Ok retain_graph=true
|
||||||
|
eval({x}, true);
|
||||||
|
|
||||||
|
// Make sure all arguments are checked
|
||||||
|
auto y = ones({2, 3});
|
||||||
|
CHECK_THROWS(eval(x, y));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test eval graph retention") {
|
||||||
|
auto x = array(1);
|
||||||
|
auto y = array(2);
|
||||||
|
auto z = x + y;
|
||||||
|
eval({z}, true);
|
||||||
|
CHECK(z.has_primitive());
|
||||||
|
CHECK(z.is_evaled());
|
||||||
|
CHECK_EQ(z.item<int>(true), 3);
|
||||||
|
CHECK(z.has_primitive());
|
||||||
|
CHECK(z.is_evaled());
|
||||||
|
|
||||||
|
CHECK_EQ(z.item<int>(), 3);
|
||||||
|
CHECK(!z.has_primitive());
|
||||||
|
CHECK(z.is_evaled());
|
||||||
|
|
||||||
|
z = x + y;
|
||||||
|
auto a = z + x;
|
||||||
|
auto b = a + y;
|
||||||
|
eval({b}, true);
|
||||||
|
CHECK(z.has_primitive());
|
||||||
|
CHECK(z.is_evaled());
|
||||||
|
CHECK(a.has_primitive());
|
||||||
|
CHECK(a.is_evaled());
|
||||||
|
|
||||||
|
eval({b}, false);
|
||||||
|
CHECK(!z.has_primitive());
|
||||||
|
CHECK(z.is_evaled());
|
||||||
|
CHECK(!a.has_primitive());
|
||||||
|
CHECK(a.is_evaled());
|
||||||
|
}
|
81
tests/load_tests.cpp
Normal file
81
tests/load_tests.cpp
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
#include <filesystem>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "doctest/doctest.h"
|
||||||
|
|
||||||
|
#include "mlx/mlx.h"
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
std::string get_temp_file(const std::string& name) {
|
||||||
|
return std::filesystem::temp_directory_path().append(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test single array serialization") {
|
||||||
|
// Basic test
|
||||||
|
{
|
||||||
|
auto a = random::uniform(-5.f, 5.f, {2, 5, 12}, float32);
|
||||||
|
|
||||||
|
std::string file_path = get_temp_file("test_arr.npy");
|
||||||
|
|
||||||
|
save(file_path, a);
|
||||||
|
auto b = load(file_path);
|
||||||
|
|
||||||
|
CHECK_EQ(a.dtype(), b.dtype());
|
||||||
|
CHECK_EQ(a.shape(), b.shape());
|
||||||
|
CHECK(array_equal(a, b).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Other shapes
|
||||||
|
{
|
||||||
|
auto a = random::uniform(
|
||||||
|
-5.f,
|
||||||
|
5.f,
|
||||||
|
{
|
||||||
|
1,
|
||||||
|
},
|
||||||
|
float32);
|
||||||
|
|
||||||
|
std::string file_path = get_temp_file("test_arr_0.npy");
|
||||||
|
|
||||||
|
save(file_path, a);
|
||||||
|
auto b = load(file_path);
|
||||||
|
|
||||||
|
CHECK_EQ(a.dtype(), b.dtype());
|
||||||
|
CHECK_EQ(a.shape(), b.shape());
|
||||||
|
CHECK(array_equal(a, b).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto a = random::uniform(
|
||||||
|
-5.f,
|
||||||
|
5.f,
|
||||||
|
{
|
||||||
|
46,
|
||||||
|
},
|
||||||
|
float32);
|
||||||
|
|
||||||
|
std::string file_path = get_temp_file("test_arr_1.npy");
|
||||||
|
|
||||||
|
save(file_path, a);
|
||||||
|
auto b = load(file_path);
|
||||||
|
|
||||||
|
CHECK_EQ(a.dtype(), b.dtype());
|
||||||
|
CHECK_EQ(a.shape(), b.shape());
|
||||||
|
CHECK(array_equal(a, b).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto a = random::uniform(-5.f, 5.f, {5, 2, 1, 3, 4}, float32);
|
||||||
|
|
||||||
|
std::string file_path = get_temp_file("test_arr_2.npy");
|
||||||
|
|
||||||
|
save(file_path, a);
|
||||||
|
auto b = load(file_path);
|
||||||
|
|
||||||
|
CHECK_EQ(a.dtype(), b.dtype());
|
||||||
|
CHECK_EQ(a.shape(), b.shape());
|
||||||
|
CHECK(array_equal(a, b).item<bool>());
|
||||||
|
}
|
||||||
|
}
|
119
tests/scheduler_tests.cpp
Normal file
119
tests/scheduler_tests.cpp
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
#include "doctest/doctest.h"
|
||||||
|
|
||||||
|
#include "mlx/mlx.h"
|
||||||
|
#include "mlx/scheduler.h"
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
TEST_CASE("test stream management") {
|
||||||
|
auto s1 = default_stream(default_device());
|
||||||
|
CHECK_EQ(s1.device, default_device());
|
||||||
|
|
||||||
|
auto s2 = new_stream(default_device());
|
||||||
|
CHECK_EQ(s2.device, default_device());
|
||||||
|
CHECK_NE(s1, s2);
|
||||||
|
|
||||||
|
// Check that default streams have the correct devices
|
||||||
|
if (metal::is_available()) {
|
||||||
|
auto s_gpu = default_stream(Device::gpu);
|
||||||
|
CHECK_EQ(s_gpu.device, Device::gpu);
|
||||||
|
} else {
|
||||||
|
CHECK_THROWS_AS(default_stream(Device::gpu), std::invalid_argument);
|
||||||
|
}
|
||||||
|
auto s_cpu = default_stream(Device::cpu);
|
||||||
|
CHECK_EQ(s_cpu.device, Device::cpu);
|
||||||
|
|
||||||
|
s_cpu = new_stream(Device::cpu);
|
||||||
|
CHECK_EQ(s_cpu.device, Device::cpu);
|
||||||
|
|
||||||
|
if (metal::is_available()) {
|
||||||
|
auto s_gpu = new_stream(Device::gpu);
|
||||||
|
CHECK_EQ(s_gpu.device, Device::gpu);
|
||||||
|
} else {
|
||||||
|
CHECK_THROWS_AS(new_stream(Device::gpu), std::invalid_argument);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test asynchronous launch") {
|
||||||
|
auto s1 = default_stream(default_device());
|
||||||
|
auto s2 = new_stream(default_device());
|
||||||
|
|
||||||
|
// Make sure streams execute asynchronously
|
||||||
|
int x = 1;
|
||||||
|
auto p1 = std::make_shared<std::promise<void>>();
|
||||||
|
auto p2 = std::make_shared<std::promise<void>>();
|
||||||
|
auto f1 = p1->get_future().share();
|
||||||
|
auto f2 = p2->get_future().share();
|
||||||
|
auto fn1 = [&x, p = std::move(p1)]() {
|
||||||
|
x++;
|
||||||
|
p->set_value();
|
||||||
|
};
|
||||||
|
auto fn2 = [&x, p = std::move(p2), f = std::move(f1)]() {
|
||||||
|
f.wait();
|
||||||
|
x *= 5;
|
||||||
|
p->set_value();
|
||||||
|
};
|
||||||
|
|
||||||
|
// fn2 is launched first and is waiting on fn1 but since
|
||||||
|
// they are on different streams there is no deadlock.
|
||||||
|
scheduler::enqueue(s2, std::move(fn2));
|
||||||
|
scheduler::enqueue(s1, std::move(fn1));
|
||||||
|
|
||||||
|
f2.wait();
|
||||||
|
|
||||||
|
CHECK_EQ(x, 10);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test stream placement") {
|
||||||
|
auto s1 = default_stream(default_device());
|
||||||
|
auto s2 = new_stream(default_device());
|
||||||
|
|
||||||
|
{
|
||||||
|
// Wait on stream 1
|
||||||
|
auto p = std::make_shared<std::promise<void>>();
|
||||||
|
auto f = p->get_future().share();
|
||||||
|
scheduler::enqueue(s1, [f = std::move(f)]() { f.wait(); });
|
||||||
|
|
||||||
|
// Do some work on stream 2
|
||||||
|
auto x = zeros({100}, float32, s2);
|
||||||
|
auto y = ones({100}, float32, s2);
|
||||||
|
auto z = add(x, y, s2);
|
||||||
|
eval(z);
|
||||||
|
p->set_value();
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Wait on stream 1
|
||||||
|
auto p = std::make_shared<std::promise<void>>();
|
||||||
|
auto f = p->get_future().share();
|
||||||
|
scheduler::enqueue(s1, [f = std::move(f)]() { f.wait(); });
|
||||||
|
|
||||||
|
// Do some work on stream 2
|
||||||
|
auto fn = [&s2](array a) { return add(a, add(a, a, s2), s2); };
|
||||||
|
auto x = zeros({100}, s2);
|
||||||
|
|
||||||
|
// The whole vjp computation should happen
|
||||||
|
// on the second stream otherwise this will hang.
|
||||||
|
auto [out, dout] = vjp(fn, x, ones({100}, s2));
|
||||||
|
|
||||||
|
// The whole jvp computation should happen on the
|
||||||
|
// second stream.
|
||||||
|
std::tie(out, dout) = jvp(fn, x, ones({100}, s2));
|
||||||
|
eval(out, dout);
|
||||||
|
|
||||||
|
p->set_value();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test scheduler races") {
|
||||||
|
auto x = zeros({1});
|
||||||
|
auto y = zeros({100});
|
||||||
|
eval(x, y);
|
||||||
|
auto a = exp(x);
|
||||||
|
eval(a);
|
||||||
|
a = exp(x);
|
||||||
|
for (int i = 0; i < 10000; ++i) {
|
||||||
|
y = exp(y);
|
||||||
|
}
|
||||||
|
eval(a, y);
|
||||||
|
}
|
26
tests/utils_tests.cpp
Normal file
26
tests/utils_tests.cpp
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
|
||||||
|
#include "doctest/doctest.h"
|
||||||
|
|
||||||
|
#include "mlx/mlx.h"
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
TEST_CASE("test type promotion") {
|
||||||
|
for (auto t : {bool_, uint32, int32, int64, float32}) {
|
||||||
|
auto a = array(0, t);
|
||||||
|
CHECK_EQ(result_type({a}), t);
|
||||||
|
|
||||||
|
std::vector<array> arrs = {array(0, t), array(0, t)};
|
||||||
|
CHECK_EQ(result_type(arrs), t);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
std::vector<array> arrs = {array(false), array(0, int32)};
|
||||||
|
CHECK_EQ(result_type(arrs), int32);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
std::vector<array> arrs = {array(0, int32), array(false), array(0.0f)};
|
||||||
|
CHECK_EQ(result_type(arrs), float32);
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user