diff --git a/README.md b/README.md
index 8dd49919c..0a9d74524 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,61 @@
-# mlx
-MLX: An array framework for Apple silicon
+# MLX
+
+MLX is an array framework for machine learning specifically targeting Apple
+Silicon. MLX is designed with inspiration from Jax, PyTorch, and ArrayFire.
+
+[Documentation](https://at.apple.com/mlx)
+
+## Build
+
+```
+mkdir -p build && cd build
+cmake .. && make -j
+```
+
+Run the C++ tests with `make test` (or `./tests/tests` for more detailed output).
+
+### Python bindings
+
+To install, run:
+
+```
+env CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
+```
+
+For development, use an editable install:
+
+```
+env CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e .
+```
+
+To make sure the install is working, run the tests with:
+
+```
+python -m unittest discover python/tests
+```
+
+
+## Develop
+
+- Fork and submit pull requests to the repo.
+
+- Every PR should have passing tests and at least one review.
+
+- If a change is likely to impact efficiency, run some of the benchmarks before
+  and after the change. Examples of benchmarks can be found in `benchmarks/cpp/`.
+
+- Install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
+  This should install hooks for running `black` and `clang-format` to ensure
+  consistent style for C++ and Python code.
+
+  You can also run the formatters manually as follows:
+
+  ```
+  clang-format -i file.cpp
+  ```
+
+  ```
+  black file.py
+  ```
+
+  or run `pre-commit run --all-files` to check all files in the repo.
diff --git a/benchmarks/cpp/CMakeLists.txt b/benchmarks/cpp/CMakeLists.txt
new file mode 100644
index 000000000..82d5ffce9
--- /dev/null
+++ b/benchmarks/cpp/CMakeLists.txt
@@ -0,0 +1,11 @@
+function(build_benchmark SRCFILE)
+  get_filename_component(src_name ${SRCFILE} NAME_WE)
+  set(target "${src_name}")
+  add_executable(${target} ${SRCFILE})
+  target_link_libraries(${target} PRIVATE mlx)
+endfunction(build_benchmark)
+
+build_benchmark(single_ops.cpp)
+build_benchmark(irregular_strides.cpp)
+build_benchmark(compare_devices.cpp)
+build_benchmark(autograd.cpp)
diff --git a/benchmarks/cpp/autograd.cpp b/benchmarks/cpp/autograd.cpp
new file mode 100644
index 000000000..fee3fc616
--- /dev/null
+++ b/benchmarks/cpp/autograd.cpp
@@ -0,0 +1,37 @@
+#include <iostream>
+
+#include "mlx/mlx.h"
+#include "time_utils.h"
+
+using namespace mlx::core;
+
+void time_value_and_grad() {
+  auto x = ones({200, 1000});
+  eval(x);
+  auto fn = [](array x) {
+    for (int i = 0; i < 20; ++i) {
+      x = log(exp(x));
+    }
+    return sum(x);
+  };
+
+  auto grad_fn = grad(fn);
+  auto independent_value_and_grad = [&]() {
+    auto value = fn(x);
+    auto dfdx = grad_fn(x);
+    return std::vector<array>{value, dfdx};
+  };
+  TIME(independent_value_and_grad);
+
+  auto value_and_grad_fn = value_and_grad(fn);
+  auto combined_value_and_grad = [&]() {
+    auto [value, dfdx] = value_and_grad_fn(x);
+    return std::vector<array>{value, dfdx};
+  };
+  TIME(combined_value_and_grad);
+}
+
+int main() {
+  std::cout << "Benchmarks for " << default_device() << std::endl;
+  time_value_and_grad();
+}
diff --git a/benchmarks/cpp/compare_devices.cpp b/benchmarks/cpp/compare_devices.cpp
new file mode 100644
index 000000000..9e7c696e6
--- /dev/null
+++ b/benchmarks/cpp/compare_devices.cpp
@@ -0,0 +1,25 @@
+#include <iostream>
+#include "mlx/mlx.h"
+#include "time_utils.h"
+
+using namespace mlx::core;
+
+void time_add_op() {
+  std::vector<int> sizes(1, 1);
+  for (int i = 0; i < 9; ++i) {
+    sizes.push_back(10 * sizes.back());
+  }
+  set_default_device(Device::cpu);
+
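+  // For each size, time add() on the pre-evaluated inputs on both the CPU
+  // and the GPU so the two devices can be compared directly.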
for (auto size : sizes) { + auto a = random::uniform({size}); + auto b = random::uniform({size}); + eval(a, b); + std::cout << "Size " << size << std::endl; + TIMEM("cpu", add, a, b, Device::cpu); + TIMEM("gpu", add, a, b, Device::gpu); + } +} + +int main() { + time_add_op(); +} diff --git a/benchmarks/numpy/single_ops.py b/benchmarks/numpy/single_ops.py new file mode 100644 index 000000000..e359d3ec0 --- /dev/null +++ b/benchmarks/numpy/single_ops.py @@ -0,0 +1,38 @@ +import numpy as np + +from time_utils import time_fn + + +def time_add(): + a = np.ones((100, 100, 10), dtype=np.float32) + b = np.ones((100, 100, 10), dtype=np.float32) + time_fn(np.add, a, b) + + +def time_matmul(): + a = np.random.rand(1000, 500).astype(np.float32) + b = np.random.rand(500, 1000).astype(np.float32) + time_fn(np.matmul, a, b) + + +def time_exp(): + a = np.random.randn(1000, 100).astype(np.float32) + time_fn(np.exp, a) + + +def time_take(): + a = np.random.rand(10000, 500) + ids = np.random.randint(0, 10000, (20, 10)) + ids = [idx.reshape(-1) for idx in np.split(ids, 20)] + + def random_take(): + return [np.take(a, idx, 0) for idx in ids] + + time_fn(random_take) + + +if __name__ == "__main__": + time_add() + time_matmul() + time_exp() + time_take() diff --git a/benchmarks/numpy/time_utils.py b/benchmarks/numpy/time_utils.py new file mode 100644 index 000000000..3741fba1d --- /dev/null +++ b/benchmarks/numpy/time_utils.py @@ -0,0 +1,18 @@ +import time + + +def time_fn(fn, *args): + print(f"Timing {fn.__name__} ...", end=" ") + + # warmup + for _ in range(5): + fn(*args) + + num_iters = 100 + tic = time.perf_counter() + for _ in range(num_iters): + x = fn(*args) + toc = time.perf_counter() + + msec = 1e3 * (toc - tic) / num_iters + print(f"{msec:.5f} msec") diff --git a/benchmarks/python/blas/bench_gemm.py b/benchmarks/python/blas/bench_gemm.py new file mode 100644 index 000000000..27cd0793b --- /dev/null +++ b/benchmarks/python/blas/bench_gemm.py @@ -0,0 +1,190 @@ +import numpy as np +import argparse +import mlx.core as mx +import time +import torch +import os +import math +import subprocess + +device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"]) +device_name = device_name.decode("utf-8").strip("\n") + +N_warmup = 8 +N_iter_bench = 80 +N_iter_func = 5 + + +def bench(f, a, b): + for i in range(N_warmup): + f(a, b) + torch.mps.synchronize() + + s = time.perf_counter_ns() + for i in range(N_iter_bench): + f(a, b) + e = time.perf_counter_ns() + return (e - s) * 1e-9 + + +def gemm_nn_mlx(a, b): + ys = [] + for i in range(N_iter_func): + y = a @ b + ys.append(y) + mx.eval(ys) + return ys + + +def gemm_nt_mlx(a, b): + ys = [] + for i in range(N_iter_func): + y = a @ b.transpose((0, 2, 1)) + ys.append(y) + mx.eval(ys) + return ys + + +def gemm_tn_mlx(a, b): + ys = [] + for i in range(N_iter_func): + y = a.transpose((0, 2, 1)) @ b + ys.append(y) + mx.eval(ys) + return ys + + +def gemm_tt_mlx(a, b): + ys = [] + for i in range(N_iter_func): + y = a.transpose((0, 2, 1)) @ b.transpose((0, 2, 1)) + ys.append(y) + mx.eval(ys) + return ys + + +@torch.no_grad() +def gemm_nn_torch(a, b): + ys = [] + for i in range(N_iter_func): + y = a @ b + ys.append(y) + torch.mps.synchronize() + return ys + + +@torch.no_grad() +def gemm_nt_torch(a, b): + ys = [] + for i in range(N_iter_func): + y = a @ b.transpose(-1, -2) + ys.append(y) + torch.mps.synchronize() + return ys + + +@torch.no_grad() +def gemm_tn_torch(a, b): + ys = [] + for i in range(N_iter_func): + y = a.transpose(-1, -2) @ b + ys.append(y) + 
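+    # Block until the queued MPS kernels finish so the timing in bench()
+    # includes the GPU work, not just kernel submission.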
torch.mps.synchronize() + return ys + + +@torch.no_grad() +def gemm_tt_torch(a, b): + ys = [] + for i in range(N_iter_func): + y = a.transpose(-1, -2) @ b.transpose(-1, -2) + ys.append(y) + torch.mps.synchronize() + return ys + + +def bench_shape(B, M, N, K, np_dtype, transpose="nn"): + shape_a = (B, M, K) if transpose[0] == "n" else (B, K, M) + shape_b = (B, K, N) if transpose[1] == "n" else (B, N, K) + + a_np = np.random.normal(0.0, 1.0 / math.sqrt(M + K), shape_a).astype(np_dtype) + b_np = np.random.normal(0.0, 1.0 / math.sqrt(N + K), shape_b).astype(np_dtype) + + a_mx = mx.array(a_np) + b_mx = mx.array(b_np) + + a_pt = torch.from_numpy(a_np).to("mps") + b_pt = torch.from_numpy(b_np).to("mps") + + torch.mps.synchronize() + + f_mx = { + "nn": gemm_nn_mlx, + "nt": gemm_nt_mlx, + "tn": gemm_tn_mlx, + "tt": gemm_tt_mlx, + }[transpose] + + f_pt = { + "nn": gemm_nn_torch, + "nt": gemm_nt_torch, + "tn": gemm_tn_torch, + "tt": gemm_tt_torch, + }[transpose] + + time_torch = bench(f_pt, a_pt, b_pt) + time_mlx = bench(f_mx, a_mx, b_mx) + + t_a = (0, 1, 2) if transpose[0] == "n" else (0, 2, 1) + t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1) + + c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b) + c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype( + np.float32 + ) + + atol = 1e-5 if np_dtype == np.float32 else 1e-4 + + if not np.allclose(c_mlx, c_npy.astype(np_dtype), atol=atol): + print( + f"Failed at {(B, M, N, K)} [transpose = {transpose}] with max(|a - b|) = {np.max(np.abs(c_npy - c_mlx))}" + ) + + return time_mlx, time_torch + + +def get_gflop_count(B, M, N, K): + return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run gemm benchmarks") + + dtypes = ("float32", "float16") + transposes = ("nn", "nt", "tn") + shapes = ( + (16, 1024, 1024, 1024), + (1, 1024, 1024, 2048), + (4, 1024, 1024, 4096), + (4, 1024, 4096, 1024), + (1, 4096, 4096, 4096), + (15, 1023, 1023, 1023), + (17, 1025, 1025, 1025), + ) + + for dtype in dtypes: + for transpose in transposes: + for B, M, N, K in shapes: + np_dtype = getattr(np, dtype) + time_mlx, time_torch = bench_shape(B, M, N, K, np_dtype, transpose) + + gflop_count = get_gflop_count(B, M, N, K) + gflops_mx = gflop_count / (time_mlx) + gflops_pt = gflop_count / (time_torch) + diff = gflops_mx / gflops_pt - 1.0 + + print( + f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. 
* diff:+5.2f}%" + ) + if gflops_pt >= 2.0 * gflops_mx: + print("ATTENTION ^^^^^^^") diff --git a/benchmarks/python/blas/bench_gemv.py b/benchmarks/python/blas/bench_gemv.py new file mode 100644 index 000000000..e95e48fa4 --- /dev/null +++ b/benchmarks/python/blas/bench_gemv.py @@ -0,0 +1,219 @@ +import matplotlib.pyplot as plt +import numpy as np +import argparse +import mlx.core as mx +import time +import torch +import os +import subprocess + + +results_dir = "./results" + +if not os.path.isdir(results_dir): + os.mkdir(results_dir) + +device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"]) +device_name = device_name.decode("utf-8").strip("\n") + +N_warmup = 5 +N_iter_bench = 50 +N_iter_func = 20 + +out_vec_sizes = [128, 512, 2048, 4096] +in_vec_sizes = [128, 512, 2048, 4096] + +benchmark_vector_lens = [] +benchmark_vector_lens += [(i + 1) * 4096 for i in range(8)][::2] +benchmark_vector_lens += [(i + 1) * 4095 for i in range(8)][::2] +benchmark_vector_lens += [(i + 1) * 4097 for i in range(8)][::2] +benchmark_vector_lens += [64, 128, 512, 1024, 2048, 11008, 32000] + +benchmark_vector_lens.sort() + + +def bench(f, m, v): + for i in range(N_warmup): + f(m, v) + torch.mps.synchronize() + + s = time.perf_counter_ns() + for i in range(N_iter_bench): + f(m, v) + e = time.perf_counter_ns() + return (e - s) * 1e-9 + + +def gemv_mlx(m, v): + ys = [] + for i in range(N_iter_func): + y = m @ v + ys.append(y) + mx.eval(ys) + return ys + + +def gemv_t_mlx(m, v): + ys = [] + for i in range(N_iter_func): + y = v @ m + ys.append(y) + mx.eval(ys) + return ys + + +@torch.no_grad() +def gemv_torch(m, v): + ys = [] + for i in range(N_iter_func): + y = m @ v + ys.append(y) + torch.mps.synchronize() + return ys + + +@torch.no_grad() +def gemv_t_torch(m, v): + ys = [] + for i in range(N_iter_func): + y = v @ m + ys.append(y) + torch.mps.synchronize() + return ys + + +def bench_lens(in_vec_len, out_vec_len, np_dtype, transpose=False): + shape_mat = (in_vec_len, out_vec_len) if transpose else (out_vec_len, in_vec_len) + shape_vec = (1, in_vec_len) if transpose else (in_vec_len, 1) + + mat_npy = np.random.normal(0.0, 2.0 / in_vec_len, shape_mat).astype(np_dtype) + vec_npy = np.random.normal(0.0, 2.0 / in_vec_len, shape_vec).astype(np_dtype) + mat_mlx = mx.array(mat_npy) + vec_mlx = mx.array(vec_npy) + mat_trc = torch.from_numpy(mat_npy).to("mps") + vec_trc = torch.from_numpy(vec_npy).to("mps") + + torch.mps.synchronize() + + time_torch = ( + bench(gemv_t_torch, mat_trc, vec_trc) + if transpose + else bench(gemv_torch, mat_trc, vec_trc) + ) + time_mlx = ( + bench(gemv_t_mlx, mat_mlx, vec_mlx) + if transpose + else bench(gemv_mlx, mat_mlx, vec_mlx) + ) + + c_mlx = ( + np.asarray(vec_mlx @ mat_mlx) if transpose else np.asarray(mat_mlx @ vec_mlx) + ) + c_npy = (vec_npy @ mat_npy) if transpose else (mat_npy @ vec_npy) + + if not np.allclose(c_mlx, c_npy, atol=2e-5): + print( + f"Failed at {shape_mat} [transpose = {transpose}] with max(|a - b|) = {np.max(np.abs(c_npy - c_mlx))}" + ) + + return time_mlx, time_torch + + +def get_gflop_count(in_vec_len, out_vec_len): + return float(2.0 * N_iter_bench * N_iter_func * in_vec_len * out_vec_len) / float( + 1024**3 + ) + + +def get_gbyte_size(in_vec_len, out_vec_len, np_dtype): + n_elem = in_vec_len * out_vec_len + in_vec_len + out_vec_len + item_size = 4 if np_dtype == np.float32 else 2 + return float(N_iter_bench * N_iter_func * n_elem * item_size) / float(1024**3) + + +def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, tranpose): + 
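+    """Benchmark gemv over the given output-vector lengths for a fixed input
+    length and plot MLX vs. Torch bandwidth (GB/s) on the provided axes."""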
np_dtype = getattr(np, dtype) + mlx_gb_s = [] + mlx_gflops = [] + pyt_gb_s = [] + pyt_gflops = [] + + for out_vec_len in out_vector_lens: + gflop_count = get_gflop_count(in_vec_len, out_vec_len) + gbyte_size = get_gbyte_size(in_vec_len, out_vec_len, np_dtype) + + time_mlx, time_torch = bench_lens(in_vec_len, out_vec_len, np_dtype, transpose) + + mlx_gb_s.append(gbyte_size / time_mlx) + pyt_gb_s.append(gbyte_size / time_torch) + + mlx_gflops.append(gflop_count / time_mlx) + pyt_gflops.append(gflop_count / time_torch) + + if transpose: + title = f"gemv_t ([1, {in_vec_len}] [{in_vec_len}, out_vec_len]) | {dtype}" + else: + title = f"gemv ([out_vec_len, {in_vec_len}] X [{in_vec_len}, 1] ) | {dtype}" + + ax.plot(out_vector_lens, mlx_gb_s, "tab:blue", label="MLX") + ax.plot(out_vector_lens, pyt_gb_s, "tab:red", label="Torch") + ax.set_title(title) + ax.set(xlabel="out_vector_len", ylabel="Performance (GB/s)") + ax.legend() + + +def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, tranpose): + np_dtype = getattr(np, dtype) + mlx_gb_s = [] + mlx_gflops = [] + pyt_gb_s = [] + pyt_gflops = [] + + for in_vec_len in in_vector_lens: + gflop_count = get_gflop_count(in_vec_len, out_vec_len) + gbyte_size = get_gbyte_size(in_vec_len, out_vec_len, np_dtype) + + time_mlx, time_torch = bench_lens(in_vec_len, out_vec_len, np_dtype, transpose) + + mlx_gb_s.append(gbyte_size / time_mlx) + pyt_gb_s.append(gbyte_size / time_torch) + + mlx_gflops.append(gflop_count / time_mlx) + pyt_gflops.append(gflop_count / time_torch) + + if transpose: + title = f"([1, in_vec_len] [in_vec_len, {out_vec_len}])" + else: + title = f"([{out_vec_len}, in_vec_len] X [in_vec_len, 1] )" + + ax.plot(in_vector_lens, mlx_gb_s, "tab:blue", label="MLX") + ax.plot(in_vector_lens, pyt_gb_s, "tab:red", label="Torch") + ax.set_title(title) + ax.set(xlabel="in_vector_len", ylabel="Performance (GB/s)") + ax.legend() + + +for transpose in (False, True): + for dtype in ("float32", "float16"): + fig, axs = plt.subplots( + len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained" + ) + + for i, in_vec_len in enumerate(in_vec_sizes): + bench_with_in_len( + axs[i][0], in_vec_len, benchmark_vector_lens, dtype, transpose + ) + + for i, out_vec_len in enumerate(out_vec_sizes): + bench_with_out_len( + axs[i][1], out_vec_len, benchmark_vector_lens, dtype, transpose + ) + + op_name = "gemv_t" if transpose else "gemv" + fig.suptitle(f"{device_name}: {dtype} {op_name}") + fig.savefig( + os.path.join( + results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf' + ) + ) + plt.close(fig) diff --git a/benchmarks/python/llama_mlx_bench.py b/benchmarks/python/llama_mlx_bench.py new file mode 100644 index 000000000..f1cea4735 --- /dev/null +++ b/benchmarks/python/llama_mlx_bench.py @@ -0,0 +1,116 @@ +import math +import time + +import mlx.core as mx +import mlx.nn as nn +import mlx.utils + + +class LlamaAttention(nn.Module): + def __init__(self, dims: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.rope = nn.RoPE(dims // num_heads, True) + self.query_proj = nn.Linear(dims, dims, False) + self.key_proj = nn.Linear(dims, dims, False) + self.value_proj = nn.Linear(dims, dims, False) + self.out_proj = nn.Linear(dims, dims, False) + + def __call__(self, queries, keys, values, mask=None, cache=None): + queries = self.query_proj(queries) + keys = self.key_proj(keys) + values = self.value_proj(values) + + num_heads = self.num_heads + B, L, D = queries.shape + queries = mx.transpose(mx.reshape(queries, (B, L, num_heads, 
-1)), (0, 2, 1, 3)) + keys = mx.transpose(mx.reshape(keys, (B, L, num_heads, -1)), (0, 2, 1, 3)) + values = mx.transpose(mx.reshape(values, (B, L, num_heads, -1)), (0, 2, 1, 3)) + + if cache is not None: + key_cache, value_cache = cache + queries = self.rope(queries, offset=key_cache.shape[2]) + keys = self.rope(keys, offset=key_cache.shape[2]) + keys = mx.concatenate([key_cache, keys], axis=2) + values = mx.concatenate([value_cache, values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + # Dimensions are [batch x num heads x sequence x hidden dim] + scale = mx.array(math.sqrt(1 / queries.shape[-1]), dtype=queries.dtype) + scores = (queries * scale) @ mx.transpose(keys, (0, 1, 3, 2)) + if mask is not None: + scores = scores + mask + scores = mx.softmax(scores, axis=-1) + values_hat = mx.reshape(mx.transpose(scores @ values, (0, 2, 1, 3)), (B, L, -1)) + + return self.out_proj(values_hat), (keys, values) + + +class LlamaEncoderLayer(nn.Module): + def __init__(self, dims: int, mlp_dims: int, num_heads: int): + super().__init__() + + self.attention = LlamaAttention(dims, num_heads) + + self.norm1 = nn.RMSNorm(dims) + self.norm2 = nn.RMSNorm(dims) + + self.linear1 = nn.Linear(dims, mlp_dims, False) + self.linear2 = nn.Linear(dims, mlp_dims, False) + self.linear3 = nn.Linear(mlp_dims, dims, False) + + def __call__(self, x, mask=None, cache=None): + y = self.norm1(x) + y, cache = self.attention(y, y, y, mask, cache) + x = x + y + + y = self.norm2(x) + a = self.linear1(y) + b = self.linear2(y) + y = a * mx.sigmoid(a) * b + y = self.linear3(y) + x = x + y + + return x, cache + + +def measure(model, x, cache): + for i in range(5): + y, c = model(x, mask=None, cache=cache) + mx.eval(y, c) + + start = time.time() + rs = [] + for i in range(5): + y, c = model(x, mask=None, cache=cache) + rs.append((y, c)) + mx.eval(rs) + end = time.time() + + return (end - start) * 1000 / 5 + + +if __name__ == "__main__": + H = 32 + D = 4096 + F = 43 * 256 + C = 1000 + mx.set_default_device(mx.gpu) + dtype = mx.float16 + + layer = LlamaEncoderLayer(D, F, H) + layer.update(mlx.utils.tree_map(lambda x: x.astype(dtype), layer.parameters())) + k1, k2, k3 = mx.random.split(mx.random.key(0), 3) + x = mx.random.normal([1, 1, D], dtype=dtype) + cache = [ + mx.random.normal([1, H, C, D // H], dtype=dtype), + mx.random.normal([1, H, C, D // H], dtype=dtype), + ] + mx.eval(x, cache) + + T = measure(layer, x, cache) + + print("Time per layer per token:", T, "ms") + print("Lower bound total time per token:", T * 32, "ms") diff --git a/cmake/extension.cmake b/cmake/extension.cmake new file mode 100644 index 000000000..383656d37 --- /dev/null +++ b/cmake/extension.cmake @@ -0,0 +1,56 @@ +include(CMakeParseArguments) + +############################################################################### +# Build metal library +# +# Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib +# from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS} +# +# Args: +# TARGET: Custom target to be added for the metal library +# TITLE: Name of the .metallib +# OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib +# SOURCES: List of source files +# INCLUDE_DIRS: List of include dirs +# DEPS: List of depedency files (like headers) +# +macro(mlx_build_metallib) + # Parse args + set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY) + set(multiValueArgs SOURCES INCLUDE_DIRS DEPS) + cmake_parse_arguments( + MTLLIB + "" + "${oneValueArgs}" + "${multiValueArgs}" + ${ARGN} + ) + + # Set output + 
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib") + + # Collect compile options + set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math) + + # Prepare metllib build command + add_custom_command( + OUTPUT ${MTLLIB_BUILD_TARGET} + COMMAND xcrun -sdk macosx metal + "$" + ${MTLLIB_COMPILE_OPTIONS} + ${MTLLIB_SOURCES} + -o ${MTLLIB_BUILD_TARGET} + DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES} + COMMAND_EXPAND_LISTS + COMMENT "Building ${MTLLIB_TITLE}.metallib" + VERBATIM + ) + + # Add metallib custom target + add_custom_target( + ${MTLLIB_TARGET} + DEPENDS + ${MTLLIB_BUILD_TARGET} + ) + +endmacro(mlx_build_metallib) \ No newline at end of file diff --git a/docs/.nojekyll b/docs/.nojekyll new file mode 100644 index 000000000..e69de29bb diff --git a/docs/index.html b/docs/index.html new file mode 100644 index 000000000..070710893 --- /dev/null +++ b/docs/index.html @@ -0,0 +1 @@ + diff --git a/docs/src/_templates/nn-module-template.rst b/docs/src/_templates/nn-module-template.rst new file mode 100644 index 000000000..49f018eb5 --- /dev/null +++ b/docs/src/_templates/nn-module-template.rst @@ -0,0 +1,19 @@ +{{ fullname | escape | underline}} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} + + {#{% block methods %} + + {% if methods %} + .. rubric:: {{ _('Methods') }} + + .. autosummary:: + {% for item in methods %} + {%- if item not in inherited_members and item != '__init__' %} + ~{{ name }}.{{ item }} + {%- endif %} + {%- endfor %} + {% endif %} + {% endblock %}#} diff --git a/docs/src/cpp/ops.rst b/docs/src/cpp/ops.rst new file mode 100644 index 000000000..4d2d1404e --- /dev/null +++ b/docs/src/cpp/ops.rst @@ -0,0 +1,6 @@ +.. _cpp_ops: + +Operations +========== + + diff --git a/docs/src/dev/extensions.rst b/docs/src/dev/extensions.rst new file mode 100644 index 000000000..36f44a4bb --- /dev/null +++ b/docs/src/dev/extensions.rst @@ -0,0 +1,948 @@ +Developer Documentation +======================= + +MLX provides a open and flexible backend to which users may add operations +and specialized implementations without much hassle. While the library supplies +efficient operations that can be used and composed for any number of +applications, there may arise cases where new functionalities or highly +optimized implementations are needed. For such cases, you may design and +implement your own operations that link to and build on top of :mod:`mlx.core`. +We will introduce the inner-workings of MLX and go over a simple example to +learn the steps involved in adding new operations to MLX with your own CPU +and GPU implementations. + +Introducing the Example +----------------------- + +Let's say that you would like an operation that takes in two arrays, +``x`` and ``y``, scales them both by some coefficents ``alpha`` and ``beta`` +respectively, and then adds them together to get the result +``z = alpha * x + beta * y``. Well, you can very easily do that by just +writing out a function as follows: + +.. code-block:: python + + import mlx.core as mx + + def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array: + return alpha * x + beta * y + +This function performs that operation while leaving the implementations and +differentiation to MLX. + +However, you work with vector math libraries often and realize that the +``axpby`` routine defines the same operation ``Y = (alpha * X) + (beta * Y)``. 
+You would really like the part of your applications that does this operation +on the CPU to be very fast - so you decide that you want it to rely on the +``axpby`` routine provided by the Accelerate_ framework. Continuing to impose +our assumptions on to you, let's also assume that you want to learn how add +your own implementation for the gradients of your new operation while going +over the ins-and-outs of the MLX framework. + +Well, what a coincidence! You are in the right place. Over the course of this +example, we will learn: + +* The structure of the MLX library from the frontend API to the backend implementations. +* How to implement your own CPU backend that redirects to Accelerate_ when appropriate (and a fallback if needed). +* How to implement your own GPU implementation using metal. +* How to add your own ``vjp`` and ``jvp``. +* How to build your implementations, link them to MLX, and bind them to python. + +Operations and Primitives +------------------------- + +In one sentence, operations in MLX build the computation graph, and primitives +provide the rules for evaluation and transformations of said graph. Let's start +by discussing operations in more detail. + +Operations +^^^^^^^^^^^ + +Operations are the frontend functions that operate on arrays. They are defined +in the C++ API (:ref:`cpp_ops`) and then we provide bindings to these +operations in the Python API (:ref:`ops`). + +We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and ``y``, +and two scalars, ``alpha`` and ``beta``. This is how we would define it in the +C++ API: + +.. code-block:: C++ + + /** + * Scale and sum two vectors elementwise + * z = alpha * x + beta * y + * + * Follow numpy style broadcasting between x and y + * Inputs are upcasted to floats if needed + **/ + array axpby( + const array& x, // Input array x + const array& y, // Input array y + const float alpha, // Scaling factor for x + const float beta, // Scaling factor for y + StreamOrDevice s = {} // Stream on which to schedule the operation + ); + + +This operation itself can call other operations within it if needed. So, the +simplest way to go about implementing this operation would be do so in terms +of existing operations. + +.. code-block:: C++ + + array axpby( + const array& x, // Input array x + const array& y, // Input array y + const float alpha, // Scaling factor for x + const float beta, // Scaling factor for y + StreamOrDevice s /* = {} */ // Stream on which to schedule the operation + ) { + // Scale x and y on the provided stream + auto ax = multiply(array(alpha), x, s); + auto by = multiply(array(beta), y, s); + + // Add and return + return add(ax, by, s); + } + +However, as we discussed earlier, this is not our goal. The operations themselves +do not contain the implementations that act on the data, nor do they contain the +rules of transformations. Rather, they are an easy to use interface that build +on top of the building blocks we call :class:`Primitive`. + +Primitives +^^^^^^^^^^^ + +A :class:`Primitive` is part of the computation graph of an :class:`array`. It +defines how to create an output given a set of input :class:`array` . Further, +a :class:`Primitive` is a class that contains rules on how it is evaluated +on the CPU or GPU, and how it acts under transformations such as ``vjp`` and +``jvp``. These words on their own can be a bit abstract, so lets take a step +back and go to our example to give ourselves a more concrete image. + +.. 
code-block:: C++ + + class Axpby : public Primitive { + public: + explicit Axpby(Stream stream, float alpha, float beta) + : Primitive(stream), alpha_(alpha), beta_(beta){}; + + /** + * A primitive must know how to evaluate itself on the CPU/GPU + * for the given inputs and populate the output array. + * + * To avoid unecessary allocations, the evaluation function + * is responsible for allocating space for the array. + */ + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + /** The Jacobian-vector product. */ + array jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) override; + + /** The vector-Jacobian product. */ + std::vector vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) override; + + /** + * The primitive must know how to vectorize itself accross + * the given axes. The output is a pair containing the array + * representing the vectorized computation and the axis which + * corresponds to the output vectorized dimension. + */ + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + /** Print the primitive. */ + void print(std::ostream& os) override { + os << "Axpby"; + } + + /** Equivalence check **/ + bool is_equivalent(const Primitive& other) const override; + + private: + float alpha_; + float beta_; + + /** Fall back implementation for evaluation on CPU */ + void eval(const std::vector& inputs, array& out); + }; + +The :class:`Axpby` class derives from the base :class:`Primitive` class and +follows the above demonstrated interface. :class:`Axpby` treats ``alpha`` and +``beta`` as parameters. It then provides implementations of how the array ``out`` +is produced given ``inputs`` through :meth:`Axpby::eval_cpu` and +:meth:`Axpby::eval_gpu`. Further, it provides rules of transformations in +:meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and :meth:`Axpby::vmap`. + +Using the Primitives +^^^^^^^^^^^^^^^^^^^^^ + +Operations can use this :class:`Primitive` to add a new :class:`array` to +the computation graph. An :class:`array` can be constructed by providing its +data type, shape, the :class:`Primitive` that computes it, and the +:class:`array` inputs that are passed to the primitive. + +Let's re-implement our operation now in terms of our :class:`Axpby` primitive. + +.. code-block:: C++ + + array axpby( + const array& x, // Input array x + const array& y, // Input array y + const float alpha, // Scaling factor for x + const float beta, // Scaling factor for y + StreamOrDevice s /* = {} */ // Stream on which to schedule the operation + ) { + // Promote dtypes between x and y as needed + auto promoted_dtype = promote_types(x.dtype(), y.dtype()); + + // Upcast to float32 for non-floating point inputs x and y + auto out_dtype = is_floating_point(promoted_dtype) + ? 
promoted_dtype + : promote_types(promoted_dtype, float32); + + // Cast x and y up to the determined dtype (on the same stream s) + auto x_casted = astype(x, out_dtype, s); + auto y_casted = astype(y, out_dtype, s); + + // Broadcast the shapes of x and y (on the same stream s) + auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s); + auto out_shape = broadcasted_inputs[0].shape(); + + // Construct the array as the output of the Axpby primitive + // with the broadcasted and upcasted arrays as inputs + return array( + /* const std::vector& shape = */ out_shape, + /* Dtype dtype = */ out_dtype, + /* std::unique_ptr primitive = */ + std::make_unique(to_stream(s), alpha, beta), + /* const std::vector& inputs = */ broadcasted_inputs); + } + + +This operation now handles the following: + +#. Upcast inputs and resolve the the output data type. +#. Broadcast the inputs and resolve the output shape. +#. Construct the primitive :class:`Axpby` using the given stream, ``alpha``, and ``beta``. +#. Construct the output :class:`array` using the primitive and the inputs. + +Implementing the Primitive +-------------------------- + +No computation happens when we call the operation alone. In effect, the +operation only builds the computation graph. When we evaluate the output +array, MLX schedules the execution of the computation graph, and calls +:meth:`Axpby::eval_cpu` or :meth:`Axpby::eval_gpu` depending on the +stream/device specified by the user. + +.. warning:: + When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called, + no memory has been allocated for the output array. It falls on the implementation + of these functions to allocate memory as needed + +Implementing the CPU Backend +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Let's start by trying to implement a naive and generic version of +:meth:`Axpby::eval_cpu`. We declared this as a private member function of +:class:`Axpby` earlier called :meth:`Axpby::eval`. + +Our naive method will go over each element of the output array, find the +corresponding input elements of ``x`` and ``y`` and perform the operation +pointwise. This is captured in the templated function :meth:`axpby_impl`. + +.. code-block:: C++ + + template + void axpby_impl( + const array& x, + const array& y, + array& out, + float alpha_, + float beta_) { + // We only allocate memory when we are ready to fill the output + // malloc_or_wait synchronously allocates available memory + // There may be a wait executed here if the allocation is requested + // under memory-pressured conditions + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + // Collect input and output data pointers + const T* x_ptr = x.data(); + const T* y_ptr = y.data(); + T* out_ptr = out.data(); + + // Cast alpha and beta to the relevant types + T alpha = static_cast(alpha_); + T beta = static_cast(beta_); + + // Do the elementwise operation for each output + for (size_t out_idx = 0; out_idx < out.size(); out_idx++) { + // Map linear indices to offsets in x and y + auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides()); + auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides()); + + // We allocate the output to be contiguous and regularly strided + // (defaults to row major) and hence it doesn't need additonal mapping + out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset]; + } + } + +Now, we would like our implementation to be able to do this pointwise operation +for all incoming floating point arrays. 
Accordingly, we add dispatches for +``float32``, ``float16``, ``bfloat16`` and ``complex64``. We throw an error +if we encounter an unexpected type. + +.. code-block:: C++ + + /** Fall back implementation for evaluation on CPU */ + void Axpby::eval(const std::vector& inputs, array& out) { + // Check the inputs (registered in the op while contructing the out array) + assert(inputs.size() == 2); + auto& x = inputs[0]; + auto& y = inputs[1]; + + // Dispatch to the correct dtype + if (out.dtype() == float32) { + return axpby_impl(x, y, out, alpha_, beta_); + } else if (out.dtype() == float16) { + return axpby_impl(x, y, out, alpha_, beta_); + } else if (out.dtype() == bfloat16) { + return axpby_impl(x, y, out, alpha_, beta_); + } else if (out.dtype() == complex64) { + return axpby_impl(x, y, out, alpha_, beta_); + } else { + throw std::runtime_error( + "Axpby is only supported for floating point types."); + } + } + +We have a fallback implementation! Now, to do what we are really here to do. +Remember we wanted to use the ``axpby`` routine provided by the Accelerate_ +framework? Well, there are 3 complications to keep in mind: + +#. Accelerate does not provide implementations of ``axpby`` for half precision + floats. We can only direct to it for ``float32`` types +#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all elements + have fixed strides between them. Possibly due to broadcasts and transposes, + we aren't guaranteed that the inputs fit this requirement. We can + only direct to Accelerate if both ``x`` and ``y`` are row contiguous or + column contiguous. +#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` inplace. + MLX expects to write out the answer to a new array. We must copy the elements + of ``y`` into the output array and use that as an input to ``axpby`` + +Let's write out an implementation that uses Accelerate in the right conditions. +It must simply allocate data for the output, copy elements of ``y`` into it, +and then call the :meth:`catlas_saxpby` from accelerate. + +.. code-block:: C++ + + template + void axpby_impl_accelerate( + const array& x, + const array& y, + array& out, + float alpha_, + float beta_) { + // Accelerate library provides catlas_saxpby which does + // Y = (alpha * X) + (beta * Y) in place + // To use it, we first copy the data in y over to the output array + + // This specialization requires both x and y be contiguous in the same mode + // i.e: corresponding linear indices in both point to corresponding elements + // The data in the output array is allocated to match the strides in y + // such that x, y, and out are contiguous in the same mode and + // no transposition is needed + out.set_data( + allocator::malloc_or_wait(y.data_size() * out.itemsize()), + y.data_size(), + y.strides(), + y.flags()); + + // We then copy over the elements using the contiguous vector specialization + copy_inplace(y, out, CopyType::Vector); + + // Get x and y pointers for catlas_saxpby + const T* x_ptr = x.data(); + T* y_ptr = out.data(); + + T alpha = static_cast(alpha_); + T beta = static_cast(beta_); + + // Call the inplace accelerate operator + catlas_saxpby( + /* N = */ out.size(), + /* ALPHA = */ alpha, + /* X = */ x_ptr, + /* INCX = */ 1, + /* BETA = */ beta, + /* Y = */ y_ptr, + /* INCY = */ 1); + } + +Great! But what about the inputs that do not fit the criteria for accelerate? +Luckily, we can always just direct back to :meth:`Axpby::eval`. + +With this in mind, lets finally implement our :meth:`Axpby::eval_cpu`. + +.. 
code-block:: C++ + + /** Evaluate primitive on CPU using accelerate specializations */ + void Axpby::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + auto& x = inputs[0]; + auto& y = inputs[1]; + + // Accelerate specialization for contiguous single precision float arrays + if (out.dtype() == float32 && + ((x.flags().row_contiguous && y.flags().row_contiguous) || + (x.flags().col_contiguous && y.flags().col_contiguous))) { + axpby_impl_accelerate(x, y, out, alpha_, beta_); + return; + } + + // Fall back to common backend if specializations are not available + eval(inputs, out); + } + +We have now hit a milestone! Just this much is enough to run the operation +:meth:`axpby` on a CPU stream! + +If you do not plan on running the operation on the GPU or using transforms on +computation graphs that contain :class:`Axpby`, you can stop implementing the +primitive here and enjoy the speed-ups you get from the Accelerate library. + +Implementing the GPU Backend +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Apple silicon devices address their GPUs using the Metal_ shading language, and +all GPU kernels in MLX are written using metal. + +.. note:: + + Here are some helpful resources if you are new to metal! + + * A walkthrough of the metal compute pipeline: `Metal Example`_ + * Documentation for metal shading language: `Metal Specification`_ + * Using metal from C++: `Metal-cpp`_ + +Let's keep the GPU algorithm simple. We will launch exactly as many threads +as there are elements in the output. Each thread will pick the element it needs +from ``x`` and ``y``, do the pointwise operation, and then update its assigned +element in the output. + +.. code-block:: C++ + + template + [[kernel]] void axpby_general( + device const T* x [[buffer(0)]], + device const T* y [[buffer(1)]], + device T* out [[buffer(2)]], + constant const float& alpha [[buffer(3)]], + constant const float& beta [[buffer(4)]], + constant const int* shape [[buffer(5)]], + constant const size_t* x_strides [[buffer(6)]], + constant const size_t* y_strides [[buffer(7)]], + constant const int& ndim [[buffer(8)]], + uint index [[thread_position_in_grid]]) { + // Convert linear indices to offsets in array + auto x_offset = elem_to_loc(index, shape, x_strides, ndim); + auto y_offset = elem_to_loc(index, shape, y_strides, ndim); + + // Do the operation and update the output + out[index] = + static_cast(alpha) * x[x_offset] + static_cast(beta) * y[y_offset]; + } + +We then need to instantiate this template for all floating point types and give +each instantiation a unique host name so we can identify the right kernel for +each data type. + +.. code-block:: C++ + + #define instantiate_axpby(type_name, type) \ + template [[host_name("axpby_general_" #type_name)]] \ + [[kernel]] void axpby_general( \ + device const type* x [[buffer(0)]], \ + device const type* y [[buffer(1)]], \ + device type* out [[buffer(2)]], \ + constant const float& alpha [[buffer(3)]], \ + constant const float& beta [[buffer(4)]], \ + constant const int* shape [[buffer(5)]], \ + constant const size_t* x_strides [[buffer(6)]], \ + constant const size_t* y_strides [[buffer(7)]], \ + constant const int& ndim [[buffer(8)]], \ + uint index [[thread_position_in_grid]]); + + instantiate_axpby(float32, float); + instantiate_axpby(float16, half); + instantiate_axpby(bflot16, bfloat16_t); + instantiate_axpby(complex64, complex64_t); + +This kernel will be compiled into a metal library ``mlx_ext.metallib`` as we +will see later in :ref:`Building with CMake`. 
In the following example, we +assume that the library ``mlx_ext.metallib`` will always be co-located with +the executable/ shared-library calling the :meth:`register_library` function. +The :meth:`register_library` function takes the library's name and potential +path (or in this case, a function that can produce the path of the metal +library) and tries to load that library if it hasn't already been registered +by the relevant static :class:`mlx::core::metal::Device` object. This is why, +it is important to package your C++ library with the metal library. We will +go over this process in more detail later. + +The logic to determine the kernel, set the inputs, resolve the grid dimensions +and dispatch it to the GPU are contained in :meth:`Axpby::eval_gpu` as shown +below. + +.. code-block:: C++ + + /** Evaluate primitive on GPU */ + void Axpby::eval_gpu(const std::vector& inputs, array& out) { + // Prepare inputs + assert(inputs.size() == 2); + auto& x = inputs[0]; + auto& y = inputs[1]; + + // Each primitive carries the stream it should execute on + // and each stream carries its device identifiers + auto& s = stream(); + // We get the needed metal device using the stream + auto& d = metal::device(s.device); + + // Allocate output memory + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + // Resolve name of kernel (corresponds to axpby.metal) + std::ostringstream kname; + kname << "axpby_" << "general_" << type_to_name(out); + + // Make sure the metal library is available and look for it + // in the same folder as this executable if needed + d.register_library("mlx_ext", metal::get_colocated_mtllib_path); + + // Make a kernel from this metal library + auto kernel = d.get_kernel(kname.str(), "mlx_ext"); + + // Prepare to encode kernel + auto compute_encoder = d.get_command_encoder(s.index); + compute_encoder->setComputePipelineState(kernel); + + // Kernel parameters are registered with buffer indices corresponding to + // those in the kernel decelaration at axpby.metal + int ndim = out.ndim(); + size_t nelem = out.size(); + + // Encode input arrays to kernel + set_array_buffer(compute_encoder, x, 0); + set_array_buffer(compute_encoder, y, 1); + + // Encode output arrays to kernel + set_array_buffer(compute_encoder, out, 2); + + // Encode alpha and beta + compute_encoder->setBytes(&alpha_, sizeof(float), 3); + compute_encoder->setBytes(&beta_, sizeof(float), 4); + + // Encode shape, strides and ndim + compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5); + compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6); + compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7); + compute_encoder->setBytes(&ndim, sizeof(int), 8); + + // We launch 1 thread for each input and make sure that the number of + // threads in any given threadgroup is not higher than the max allowed + size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup()); + + // Fix the 3D size of each threadgroup (in terms of threads) + MTL::Size group_dims = MTL::Size(tgp_size, 1, 1); + + // Fix the 3D size of the launch grid (in terms of threads) + MTL::Size grid_dims = MTL::Size(nelem, 1, 1); + + // Launch the grid with the given number of threads divded among + // the given threadgroups + compute_encoder->dispatchThreads(grid_dims, group_dims); + } + +We can now call the :meth:`axpby` operation on both the CPU and the GPU! + +A few things to note about MLX and metal before moving on. MLX keeps track +of the active ``compute_encoder``. 
We rely on :meth:`d.get_command_encoder` +to give us the active metal compute command encoder instead of building a +new one and calling :meth:`compute_encoder->end_encoding` at the end. +MLX keeps adding kernels (compute pipelines) to the active command encoder +until some specified limit is hit or the compute encoder needs to be flushed +for synchronization. MLX also handles enqueuing and commiting the associated +command buffers as needed. We suggest taking a deeper dive into +:class:`metal::Device` if you would like to study this routine further. + +Primitive Transforms +^^^^^^^^^^^^^^^^^^^^^ + +Now that we have come this far, let's also learn how to add implementations to +transformations in a :class:`Primitive`. These transformations can be built on +top of our operations, including the one we just defined now. Which then gives +us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations. + +.. code-block:: C++ + + /** The Jacobian-vector product. */ + array Axpby::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + // Forward mode diff that pushes along the tangents + // The jvp transform on the the primitive can built with ops + // that are scheduled on the same stream as the primtive + + // If argnums = {0}, we only push along x in which case the + // jvp is just the tangent scaled by alpha + // Similarly, if argnums = {1}, the jvp is just the tangent + // scaled by beta + if (argnums.size() > 1) { + auto scale = argnums[0] == 0 ? alpha_ : beta_; + auto scale_arr = array(scale, tangents[0].dtype()); + return multiply(scale_arr, tangents[0], stream()); + } + // If, argnums = {0, 1}, we take contributions from both + // which gives us jvp = tangent_x * alpha + tangent_y * beta + else { + return axpby(tangents[0], tangents[1], alpha_, beta_, stream()); + } + } + +.. code-block:: C++ + + /** The vector-Jacobian product. */ + std::vector Axpby::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + // Reverse mode diff + std::vector vjps; + for (auto arg : argnums) { + auto scale = arg == 0 ? alpha_ : beta_; + auto scale_arr = array(scale, cotan.dtype()); + vjps.push_back(multiply(scale_arr, cotan, stream())); + } + return vjps; + } + +Finally, you need not have a transformation fully defined to start using your +own :class:`Primitive`. + +.. code-block:: C++ + + /** Vectorize primitve along given axis */ + std::pair Axpby::vmap( + const std::vector& inputs, + const std::vector& axes) { + throw std::runtime_error("Axpby has no vmap implementation."); + } + +Building and Binding +-------------------- + +Let's look at the overall directory structure first. + +| extensions +| ├── axpby +| │ ├── axpby.cpp +| │ ├── axpby.h +| │ └── axpby.metal +| ├── mlx_sample_extensions +| │ └── __init__.py +| ├── bindings.cpp +| ├── CMakeLists.txt +| └── setup.py + +* ``extensions/axpby/`` defines the C++ extension library +* ``extensions/mlx_sample_extensions`` sets out the strucutre for the + associated python package +* ``extensions/bindings.cpp`` provides python bindings for our operation +* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and + python bindings +* ``extensions/setup.py`` holds the ``setuptools`` rules to build and install + the python package + +Binding to Python +^^^^^^^^^^^^^^^^^^ + +We use PyBind11_ to build a Python API for the C++ library. Since bindings +for all needed components such as `mlx.core.array`, `mlx.core.stream`, etc. 
+are already provided, adding our :meth:`axpby` becomes very simple! + +.. code-block:: C++ + + PYBIND11_MODULE(mlx_sample_extensions, m) { + m.doc() = "Sample C++ and metal extensions for MLX"; + + m.def( + "axpby", + &axpby, + "x"_a, + "y"_a, + py::pos_only(), + "alpha"_a, + "beta"_a, + py::kw_only(), + "stream"_a = py::none(), + R"pbdoc( + Scale and sum two vectors elementwise + ``z = alpha * x + beta * y`` + + Follows numpy style broadcasting between ``x`` and ``y`` + Inputs are upcasted to floats if needed + + Args: + x (array): Input array. + y (array): Input array. + alpha (float): Scaling factor for ``x``. + beta (float): Scaling factor for ``y``. + + Returns: + array: ``alpha * x + beta * y`` + )pbdoc"); + } + +Most of the complexity in the above example comes from additional bells and +whistles such as the literal names and doc-strings. + +.. warning:: + + :mod:`mlx.core` needs to be imported before importing + :mod:`mlx_sample_extensions` as defined by the pybind11 module above to + ensure that the casters for :mod:`mlx.core` components like + :class:`mlx.core.array` are available. + +.. _Building with CMake: + +Building with CMake +^^^^^^^^^^^^^^^^^^^^ + +Building the C++ extension library itself is simple, it only requires that you +``find_package(MLX CONFIG)`` and then link it to your library. + +.. code-block:: cmake + + # Add library + add_library(mlx_ext) + + # Add sources + target_sources( + mlx_ext + PUBLIC + ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp + ) + + # Add include headers + target_include_directories( + mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR} + ) + + # Link to mlx + target_link_libraries(mlx_ext PUBLIC mlx) + +We also need to build the attached metal library. For convenience, we provide a +:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given +sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and +automatically imported with MLX package). + +Here is what that looks like in practice! + +.. code-block:: cmake + + # Build metallib + if(MLX_BUILD_METAL) + + mlx_build_metallib( + TARGET mlx_ext_metallib + TITLE mlx_ext + SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal + INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS} + OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY} + ) + + add_dependencies( + mlx_ext + mlx_ext_metallib + ) + + endif() + +Finally, we build the Pybind11_ bindings + +.. code-block:: cmake + + pybind11_add_module( + mlx_sample_extensions + ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp + ) + target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext) + + if(BUILD_SHARED_LIBS) + target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path) + endif() + +Building with ``setuptools`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Once we have set out the CMake build rules as described above, we can use the +build utilities defined in :mod:`mlx.extension` for a simple build process. + +.. code-block:: python + + from mlx import extension + from setuptools import setup + + if __name__ == "__main__": + setup( + name="mlx_sample_extensions", + version="0.0.0", + description="Sample C++ and Metal extensions for MLX primitives.", + ext_modules=[extension.CMakeExtension("mlx_sample_extensions")], + cmdclass={"build_ext": extension.CMakeBuild}, + packages = ["mlx_sample_extensions"], + package_dir = {"": "mlx_sample_extensions"}, + package_data = {"mlx_sample_extensions" : ["*.so", "*.dylib", "*.metallib"]}, + zip_safe=False, + python_requires=">=3.7", + ) + +.. 
note:: + We treat ``extensions/mlx_sample_extensions`` as the package directory + even though it only contains a ``__init__.py`` to ensure the following: + + * :mod:`mlx.core` is always imported before importing :mod:`mlx_sample_extensions` + * The C++ extension library and the metal library are co-located with the python + bindings and copied together if the package is installed + +You can build inplace for development using +``python setup.py build_ext -j8 --inplace`` (in ``extensions/``) + +This will result in a directory structure as follows: + +| extensions +| ├── mlx_sample_extensions +| │ ├── __init__.py +| │ ├── libmlx_ext.dylib # C++ extension library +| │ ├── mlx_ext.metallib # Metal library +| │ └── mlx_sample_extensions.cpython-3x-darwin.so # Python Binding +| ... + +When you try to install using the command ``python -m pip install .`` +(in ``extensions/``), the package will be installed with the same strucutre as +``extensions/mlx_sample_extensions`` and the C++ and metal library will be +copied along with the python binding since they are specified as ``package_data``. + +Usage +----- + +After installing the extension as described above, you should be able to simply +import the python package and play with it as you would any other MLX operation! + +Let's looks at a simple script and it's results! + +.. code-block:: python + + import mlx.core as mx + from mlx_sample_extensions import axpby + + a = mx.ones((3, 4)) + b = mx.ones((3, 4)) + c = axpby(a, b, 4.0, 2.0, stream=mx.cpu) + + print(f"c shape: {c.shape}") + print(f"c dtype: {c.dtype}") + print(f"c correctness: {mx.all(c == 6.0).item()}") + +Output: + +.. code-block:: + + c shape: [3, 4] + c dtype: float32 + c correctness: True + +Results +^^^^^^^^^^^^^^^^ + +Let's run a quick benchmark and see how our new ``axpby`` operation compares +with the naive :meth:`simple_axpby` we defined at first on the CPU. + +.. code-block:: python + + import mlx.core as mx + from mlx_sample_extensions import axpby + import time + + mx.set_default_device(mx.cpu) + + def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array: + return alpha * x + beta * y + + M = 256 + N = 512 + + x = mx.random.normal((M, N)) + y = mx.random.normal((M, N)) + alpha = 4.0 + beta = 2.0 + + mx.eval((x, y)) + + def bench(f): + # Warm up + for i in range(100): + z = f(x, y, alpha, beta) + mx.eval(z) + + # Timed run + s = time.time() + for i in range(5000): + z = f(x, y, alpha, beta) + mx.eval(z) + e = time.time() + return e - s + + simple_time = bench(simple_axpby) + custom_time = bench(axpby) + + print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s") + +Results: + +.. code-block:: + + Simple axpby: 0.114 s | Custom axpby: 0.109 s + +We see some modest improvements right away! + +This operation is now good to be used to build other operations, +in :class:`mlx.nn.Module` calls, and also as a part of graph +transformations such as :meth:`grad` and :meth:`simplify`! + +Scripts +------- + +.. admonition:: Download the code + + The full example code is available in `mlx-examples `_. + +.. code: `TODO_LINK/extensions`_ + +.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc +.. _Metal: https://developer.apple.com/documentation/metal?language=objc +.. _Metal-cpp: https://developer.apple.com/metal/cpp/ +.. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf +.. 
_`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc +.. _PyBind11: https://pybind11.readthedocs.io/en/stable/ \ No newline at end of file diff --git a/docs/src/examples/linear_regression.rst b/docs/src/examples/linear_regression.rst new file mode 100644 index 000000000..19fc0d435 --- /dev/null +++ b/docs/src/examples/linear_regression.rst @@ -0,0 +1,77 @@ +.. _linear_regression: + +Linear Regression +----------------- + +Let's implement a basic linear regression model as a starting point to +learn MLX. First import the core package and setup some problem metadata: + +.. code-block:: python + + import mlx.core as mx + + num_features = 100 + num_examples = 1_000 + num_iters = 10_000 # iterations of SGD + lr = 0.01 # learning rate for SGD + + +We'll generate a synthetic dataset by: + +1. Sampling the design matrix ``X``. +2. Sampling a ground truth parameter vector ``w_star``. +3. Compute the dependent values ``y`` by adding Gaussian noise to ``X @ w_star``. + +.. code-block:: python + + # True parameters + w_star = mx.random.normal((num_features,)) + + # Input examples (design matrix) + X = mx.random.normal((num_examples, num_features)) + + # Noisy labels + eps = 1e-2 * mx.random.normal((num_examples,)) + y = X @ w_star + eps + + +We will use SGD to find the optimal weights. To start, define the squared loss +and get the gradient function of the loss with respect to the parameters. + +.. code-block:: python + + def loss_fn(w): + return 0.5 * mx.mean(mx.square(X @ w - y)) + + grad_fn = mx.grad(loss_fn) + +Start the optimization by initializing the parameters ``w`` randomly. Then +repeatedly update the parameters for ``num_iters`` iterations. + +.. code-block:: python + + w = 1e-2 * mx.random.normal((num_features,)) + + for _ in range(num_iters): + grad = grad_fn(w) + w = w - lr * grad + mx.eval(w) + +Finally, compute the loss of the learned parameters and verify that they are +close to the ground truth parameters. + +.. code-block:: python + + loss = loss_fn(w) + error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5 + + print( + f"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}, " + ) + # Should print something close to: Loss 0.00005, |w-w*| = 0.00364 + +Complete `linear regression +`_ +and `logistic regression +`_ +examples are available in the MLX GitHub repo. diff --git a/docs/src/python/data_types.rst b/docs/src/python/data_types.rst new file mode 100644 index 000000000..cbb5c9a3f --- /dev/null +++ b/docs/src/python/data_types.rst @@ -0,0 +1,52 @@ +.. _data_types: + +:orphan: + +Data Types +========== + +.. currentmodule:: mlx.core + +The default floating point type is ``float32`` and the default integer type is +``int32``. The table below shows supported values for :obj:`Dtype`. + +.. 
+
+Complete `linear regression
+`_
+and `logistic regression
+`_
+examples are available in the MLX GitHub repo.
diff --git a/docs/src/python/data_types.rst b/docs/src/python/data_types.rst
new file mode 100644
index 000000000..cbb5c9a3f
--- /dev/null
+++ b/docs/src/python/data_types.rst
@@ -0,0 +1,52 @@
+.. _data_types:
+
+:orphan:
+
+Data Types
+==========
+
+.. currentmodule:: mlx.core
+
+The default floating point type is ``float32`` and the default integer type is
+``int32``. The table below shows supported values for :obj:`Dtype`.
+
+.. list-table:: Supported Data Types
+   :widths: 5 3 20
+   :header-rows: 1
+
+   * - Type
+     - Bytes
+     - Description
+   * - ``bool_``
+     - 1
+     - Boolean (``True``, ``False``) data type
+   * - ``uint8``
+     - 1
+     - 8-bit unsigned integer
+   * - ``uint16``
+     - 2
+     - 16-bit unsigned integer
+   * - ``uint32``
+     - 4
+     - 32-bit unsigned integer
+   * - ``uint64``
+     - 8
+     - 64-bit unsigned integer
+   * - ``int8``
+     - 1
+     - 8-bit signed integer
+   * - ``int16``
+     - 2
+     - 16-bit signed integer
+   * - ``int32``
+     - 4
+     - 32-bit signed integer
+   * - ``int64``
+     - 8
+     - 64-bit signed integer
+   * - ``float16``
+     - 2
+     - 16-bit float, only available with `ARM C language extensions `_
+   * - ``float32``
+     - 4
+     - 32-bit float
diff --git a/docs/src/python/devices_and_streams.rst b/docs/src/python/devices_and_streams.rst
new file mode 100644
index 000000000..bb9dfae2f
--- /dev/null
+++ b/docs/src/python/devices_and_streams.rst
@@ -0,0 +1,17 @@
+.. _devices_and_streams:
+
+Devices and Streams
+===================
+
+.. currentmodule:: mlx.core
+
+.. autosummary::
+   :toctree: _autosummary
+
+   Device
+   default_device
+   set_default_device
+   Stream
+   default_stream
+   new_stream
+   set_default_stream
diff --git a/docs/src/python/transforms.rst b/docs/src/python/transforms.rst
new file mode 100644
index 000000000..cc8d681d5
--- /dev/null
+++ b/docs/src/python/transforms.rst
@@ -0,0 +1,16 @@
+.. _transforms:
+
+Transforms
+==========
+
+.. currentmodule:: mlx.core
+
+.. autosummary::
+   :toctree: _autosummary
+
+   eval
+   grad
+   value_and_grad
+   jvp
+   vjp
+   vmap
diff --git a/docs/src/python/tree_utils.rst b/docs/src/python/tree_utils.rst
new file mode 100644
index 000000000..84d5afa9b
--- /dev/null
+++ b/docs/src/python/tree_utils.rst
@@ -0,0 +1,21 @@
+.. _utils:
+
+Tree Utils
+==========
+
+In MLX we consider a python tree to be an arbitrarily nested collection of
+dictionaries, lists, and tuples without cycles. Functions in this module that
+return python trees will use the default python ``dict``, ``list``, and
+``tuple``, but they can usually process objects that inherit from any of these.
+
+.. note::
+   Dictionaries should have keys that are valid python identifiers.
+
+.. currentmodule:: mlx.utils
+
+.. autosummary::
+   :toctree: _autosummary
+
+   tree_flatten
+   tree_unflatten
+   tree_map
diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt
new file mode 100644
index 000000000..71c6472aa
--- /dev/null
+++ b/examples/cpp/CMakeLists.txt
@@ -0,0 +1,10 @@
+function(build_example SRCFILE)
+  get_filename_component(src_name ${SRCFILE} NAME_WE)
+  set(target "${src_name}")
+  add_executable(${target} ${SRCFILE})
+  target_link_libraries(${target} PRIVATE mlx)
+endfunction(build_example)
+
+build_example(tutorial.cpp)
+build_example(linear_regression.cpp)
+build_example(logistic_regression.cpp)
diff --git a/examples/cpp/linear_regression.cpp b/examples/cpp/linear_regression.cpp
new file mode 100644
index 000000000..e6897b7a7
--- /dev/null
+++ b/examples/cpp/linear_regression.cpp
@@ -0,0 +1,52 @@
+#include <chrono>
+#include <cmath>
+#include <iostream>
+
+#include "mlx/mlx.h"
+#include "timer.h"
+
+/**
+ * An example of linear regression with MLX.
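+ * It fits the parameters w with plain SGD, minimizing
+ * (0.5 / num_examples) * ||X w - y||^2 (see loss_fn below).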
+ */ +using namespace mlx::core; + +int main() { + int num_features = 100; + int num_examples = 1'000; + int num_iters = 10'000; + float learning_rate = 0.01; + + // True parameters + auto w_star = random::normal({num_features}); + + // The input examples (design matrix) + auto X = random::normal({num_examples, num_features}); + + // Noisy labels + auto eps = 1e-2 * random::normal({num_examples}); + auto y = matmul(X, w_star) + eps; + + // Initialize random parameters + array w = 1e-2 * random::normal({num_features}); + + auto loss_fn = [&](array w) { + auto yhat = matmul(X, w); + return (0.5f / num_examples) * sum(square(yhat - y)); + }; + + auto grad_fn = grad(loss_fn); + + auto tic = timer::time(); + for (int it = 0; it < num_iters; ++it) { + auto grad = grad_fn(w); + w = w - learning_rate * grad; + eval(w); + } + auto toc = timer::time(); + + auto loss = loss_fn(w); + auto error_norm = std::sqrt(sum(square(w - w_star)).item()); + auto throughput = num_iters / timer::seconds(toc - tic); + std::cout << "Loss " << loss << ", |w - w*| = " << error_norm + << ", Throughput " << throughput << " (it/s)." << std::endl; +} diff --git a/mlx/3rdparty/pocketfft.h b/mlx/3rdparty/pocketfft.h new file mode 100644 index 000000000..03a45897a --- /dev/null +++ b/mlx/3rdparty/pocketfft.h @@ -0,0 +1,3581 @@ +/* +This file is part of pocketfft. + +Copyright (C) 2010-2022 Max-Planck-Society +Copyright (C) 2019-2020 Peter Bell + +For the odd-sized DCT-IV transforms: + Copyright (C) 2003, 2007-14 Matteo Frigo + Copyright (C) 2003, 2007-14 Massachusetts Institute of Technology + +Authors: Martin Reinecke, Peter Bell + +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, this + list of conditions and the following disclaimer in the documentation and/or + other materials provided with the distribution. +* Neither the name of the copyright holder nor the names of its contributors may + be used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#ifndef POCKETFFT_HDRONLY_H +#define POCKETFFT_HDRONLY_H + +#ifndef __cplusplus +#error This file is C++ and requires a C++ compiler. +#endif + +#if !(__cplusplus >= 201103L || _MSVC_LANG+0L >= 201103L) +#error This file requires at least C++11 support. 
+#endif + +#ifndef POCKETFFT_CACHE_SIZE +#define POCKETFFT_CACHE_SIZE 0 +#endif + +#include +#include +#include +#include +#include +#include +#include +#if POCKETFFT_CACHE_SIZE!=0 +#include +#include +#endif + +#ifndef POCKETFFT_NO_MULTITHREADING +#include +#include +#include +#include +#include +#include +#include + +#ifdef POCKETFFT_PTHREADS +# include +#endif +#endif + +#if defined(__GNUC__) +#define POCKETFFT_NOINLINE __attribute__((noinline)) +#define POCKETFFT_RESTRICT __restrict__ +#elif defined(_MSC_VER) +#define POCKETFFT_NOINLINE __declspec(noinline) +#define POCKETFFT_RESTRICT __restrict +#else +#define POCKETFFT_NOINLINE +#define POCKETFFT_RESTRICT +#endif + +namespace pocketfft { + +namespace detail { +using std::size_t; +using std::ptrdiff_t; + +// Always use std:: for functions +template T cos(T) = delete; +template T sin(T) = delete; +template T sqrt(T) = delete; + +using shape_t = std::vector; +using stride_t = std::vector; + +constexpr bool FORWARD = true, + BACKWARD = false; + +// only enable vector support for gcc>=5.0 and clang>=5.0 +#ifndef POCKETFFT_NO_VECTORS +#define POCKETFFT_NO_VECTORS +#if defined(__INTEL_COMPILER) +// do nothing. This is necessary because this compiler also sets __GNUC__. +#elif defined(__clang__) +// AppleClang has their own version numbering +#ifdef __apple_build_version__ +# if (__clang_major__ > 9) || (__clang_major__ == 9 && __clang_minor__ >= 1) +# undef POCKETFFT_NO_VECTORS +# endif +#elif __clang_major__ >= 5 +# undef POCKETFFT_NO_VECTORS +#endif +#elif defined(__GNUC__) +#if __GNUC__>=5 +#undef POCKETFFT_NO_VECTORS +#endif +#endif +#endif + +template struct VLEN { static constexpr size_t val=1; }; + +#ifndef POCKETFFT_NO_VECTORS +#if (defined(__AVX512F__)) +template<> struct VLEN { static constexpr size_t val=16; }; +template<> struct VLEN { static constexpr size_t val=8; }; +#elif (defined(__AVX__)) +template<> struct VLEN { static constexpr size_t val=8; }; +template<> struct VLEN { static constexpr size_t val=4; }; +#elif (defined(__SSE2__)) +template<> struct VLEN { static constexpr size_t val=4; }; +template<> struct VLEN { static constexpr size_t val=2; }; +#elif (defined(__VSX__)) +template<> struct VLEN { static constexpr size_t val=4; }; +template<> struct VLEN { static constexpr size_t val=2; }; +#elif (defined(__ARM_NEON__) || defined(__ARM_NEON)) +template<> struct VLEN { static constexpr size_t val=4; }; +template<> struct VLEN { static constexpr size_t val=2; }; +#else +#define POCKETFFT_NO_VECTORS +#endif +#endif + +// the __MINGW32__ part in the conditional below works around the problem that +// the standard C++ library on Windows does not provide aligned_alloc() even +// though the MinGW compiler and MSVC may advertise C++17 compliance. 
+#if (__cplusplus >= 201703L) && (!defined(__MINGW32__)) && (!defined(_MSC_VER)) +inline void *aligned_alloc(size_t align, size_t size) + { + // aligned_alloc() requires that the requested size is a multiple of "align" + void *ptr = ::aligned_alloc(align,(size+align-1)&(~(align-1))); + if (!ptr) throw std::bad_alloc(); + return ptr; + } +inline void aligned_dealloc(void *ptr) + { free(ptr); } +#else // portable emulation +inline void *aligned_alloc(size_t align, size_t size) + { + align = std::max(align, alignof(max_align_t)); + void *ptr = malloc(size+align); + if (!ptr) throw std::bad_alloc(); + void *res = reinterpret_cast + ((reinterpret_cast(ptr) & ~(uintptr_t(align-1))) + uintptr_t(align)); + (reinterpret_cast(res))[-1] = ptr; + return res; + } +inline void aligned_dealloc(void *ptr) + { if (ptr) free((reinterpret_cast(ptr))[-1]); } +#endif + +template class arr + { + private: + T *p; + size_t sz; + +#if defined(POCKETFFT_NO_VECTORS) + static T *ralloc(size_t num) + { + if (num==0) return nullptr; + void *res = malloc(num*sizeof(T)); + if (!res) throw std::bad_alloc(); + return reinterpret_cast(res); + } + static void dealloc(T *ptr) + { free(ptr); } +#else + static T *ralloc(size_t num) + { + if (num==0) return nullptr; + void *ptr = aligned_alloc(64, num*sizeof(T)); + return static_cast(ptr); + } + static void dealloc(T *ptr) + { aligned_dealloc(ptr); } +#endif + + public: + arr() : p(0), sz(0) {} + arr(size_t n) : p(ralloc(n)), sz(n) {} + arr(arr &&other) + : p(other.p), sz(other.sz) + { other.p=nullptr; other.sz=0; } + ~arr() { dealloc(p); } + + void resize(size_t n) + { + if (n==sz) return; + dealloc(p); + p = ralloc(n); + sz = n; + } + + T &operator[](size_t idx) { return p[idx]; } + const T &operator[](size_t idx) const { return p[idx]; } + + T *data() { return p; } + const T *data() const { return p; } + + size_t size() const { return sz; } + }; + +template struct cmplx { + T r, i; + cmplx() {} + cmplx(T r_, T i_) : r(r_), i(i_) {} + void Set(T r_, T i_) { r=r_; i=i_; } + void Set(T r_) { r=r_; i=T(0); } + cmplx &operator+= (const cmplx &other) + { r+=other.r; i+=other.i; return *this; } + templatecmplx &operator*= (T2 other) + { r*=other; i*=other; return *this; } + templatecmplx &operator*= (const cmplx &other) + { + T tmp = r*other.r - i*other.i; + i = r*other.i + i*other.r; + r = tmp; + return *this; + } + templatecmplx &operator+= (const cmplx &other) + { r+=other.r; i+=other.i; return *this; } + templatecmplx &operator-= (const cmplx &other) + { r-=other.r; i-=other.i; return *this; } + template auto operator* (const T2 &other) const + -> cmplx + { return {r*other, i*other}; } + template auto operator+ (const cmplx &other) const + -> cmplx + { return {r+other.r, i+other.i}; } + template auto operator- (const cmplx &other) const + -> cmplx + { return {r-other.r, i-other.i}; } + template auto operator* (const cmplx &other) const + -> cmplx + { return {r*other.r-i*other.i, r*other.i + i*other.r}; } + template auto special_mul (const cmplx &other) const + -> cmplx + { + using Tres = cmplx; + return fwd ? Tres(r*other.r+i*other.i, i*other.r-r*other.i) + : Tres(r*other.r-i*other.i, r*other.i+i*other.r); + } +}; +template inline void PM(T &a, T &b, T c, T d) + { a=c+d; b=c-d; } +template inline void PMINPLACE(T &a, T &b) + { T t = a; a+=b; b=t-b; } +template inline void MPINPLACE(T &a, T &b) + { T t = a; a-=b; b=t+b; } +template cmplx conj(const cmplx &a) + { return {a.r, -a.i}; } +template void special_mul (const cmplx &v1, const cmplx &v2, cmplx &res) + { + res = fwd ? 
cmplx(v1.r*v2.r+v1.i*v2.i, v1.i*v2.r-v1.r*v2.i) + : cmplx(v1.r*v2.r-v1.i*v2.i, v1.r*v2.i+v1.i*v2.r); + } + +template void ROT90(cmplx &a) + { auto tmp_=a.r; a.r=-a.i; a.i=tmp_; } +template void ROTX90(cmplx &a) + { auto tmp_= fwd ? -a.r : a.r; a.r = fwd ? a.i : -a.i; a.i=tmp_; } + +// +// twiddle factor section +// +template class sincos_2pibyn + { + private: + using Thigh = typename std::conditional<(sizeof(T)>sizeof(double)), T, double>::type; + size_t N, mask, shift; + arr> v1, v2; + + static cmplx calc(size_t x, size_t n, Thigh ang) + { + x<<=3; + if (x<4*n) // first half + { + if (x<2*n) // first quadrant + { + if (x(std::cos(Thigh(x)*ang), std::sin(Thigh(x)*ang)); + return cmplx(std::sin(Thigh(2*n-x)*ang), std::cos(Thigh(2*n-x)*ang)); + } + else // second quadrant + { + x-=2*n; + if (x(-std::sin(Thigh(x)*ang), std::cos(Thigh(x)*ang)); + return cmplx(-std::cos(Thigh(2*n-x)*ang), std::sin(Thigh(2*n-x)*ang)); + } + } + else + { + x=8*n-x; + if (x<2*n) // third quadrant + { + if (x(std::cos(Thigh(x)*ang), -std::sin(Thigh(x)*ang)); + return cmplx(std::sin(Thigh(2*n-x)*ang), -std::cos(Thigh(2*n-x)*ang)); + } + else // fourth quadrant + { + x-=2*n; + if (x(-std::sin(Thigh(x)*ang), -std::cos(Thigh(x)*ang)); + return cmplx(-std::cos(Thigh(2*n-x)*ang), -std::sin(Thigh(2*n-x)*ang)); + } + } + } + + public: + POCKETFFT_NOINLINE sincos_2pibyn(size_t n) + : N(n) + { + constexpr auto pi = 3.141592653589793238462643383279502884197L; + Thigh ang = Thigh(0.25L*pi/n); + size_t nval = (n+2)/2; + shift = 1; + while((size_t(1)< operator[](size_t idx) const + { + if (2*idx<=N) + { + auto x1=v1[idx&mask], x2=v2[idx>>shift]; + return cmplx(T(x1.r*x2.r-x1.i*x2.i), T(x1.r*x2.i+x1.i*x2.r)); + } + idx = N-idx; + auto x1=v1[idx&mask], x2=v2[idx>>shift]; + return cmplx(T(x1.r*x2.r-x1.i*x2.i), -T(x1.r*x2.i+x1.i*x2.r)); + } + }; + +struct util // hack to avoid duplicate symbols + { + static POCKETFFT_NOINLINE size_t largest_prime_factor (size_t n) + { + size_t res=1; + while ((n&1)==0) + { res=2; n>>=1; } + for (size_t x=3; x*x<=n; x+=2) + while ((n%x)==0) + { res=x; n/=x; } + if (n>1) res=n; + return res; + } + + static POCKETFFT_NOINLINE double cost_guess (size_t n) + { + constexpr double lfp=1.1; // penalty for non-hardcoded larger factors + size_t ni=n; + double result=0.; + while ((n&1)==0) + { result+=2; n>>=1; } + for (size_t x=3; x*x<=n; x+=2) + while ((n%x)==0) + { + result+= (x<=5) ? double(x) : lfp*double(x); // penalize larger prime factors + n/=x; + } + if (n>1) result+=(n<=5) ? 
double(n) : lfp*double(n); + return result*double(ni); + } + + /* returns the smallest composite of 2, 3, 5, 7 and 11 which is >= n */ + static POCKETFFT_NOINLINE size_t good_size_cmplx(size_t n) + { + if (n<=12) return n; + + size_t bestfac=2*n; + for (size_t f11=1; f11n) + { + if (x>=1; + } + else + return n; + } + } + return bestfac; + } + + /* returns the smallest composite of 2, 3, 5 which is >= n */ + static POCKETFFT_NOINLINE size_t good_size_real(size_t n) + { + if (n<=6) return n; + + size_t bestfac=2*n; + for (size_t f5=1; f5n) + { + if (x>=1; + } + else + return n; + } + } + return bestfac; + } + + static size_t prod(const shape_t &shape) + { + size_t res=1; + for (auto sz: shape) + res*=sz; + return res; + } + + static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, bool inplace) + { + auto ndim = shape.size(); + if (ndim<1) throw std::runtime_error("ndim must be >= 1"); + if ((stride_in.size()!=ndim) || (stride_out.size()!=ndim)) + throw std::runtime_error("stride dimension mismatch"); + if (inplace && (stride_in!=stride_out)) + throw std::runtime_error("stride mismatch"); + } + + static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, bool inplace, + const shape_t &axes) + { + sanity_check(shape, stride_in, stride_out, inplace); + auto ndim = shape.size(); + shape_t tmp(ndim,0); + for (auto ax : axes) + { + if (ax>=ndim) throw std::invalid_argument("bad axis number"); + if (++tmp[ax]>1) throw std::invalid_argument("axis specified repeatedly"); + } + } + + static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, bool inplace, + size_t axis) + { + sanity_check(shape, stride_in, stride_out, inplace); + if (axis>=shape.size()) throw std::invalid_argument("bad axis number"); + } + +#ifdef POCKETFFT_NO_MULTITHREADING + static size_t thread_count (size_t /*nthreads*/, const shape_t &/*shape*/, + size_t /*axis*/, size_t /*vlen*/) + { return 1; } +#else + static size_t thread_count (size_t nthreads, const shape_t &shape, + size_t axis, size_t vlen) + { + if (nthreads==1) return 1; + size_t size = prod(shape); + size_t parallel = size / (shape[axis] * vlen); + if (shape[axis] < 1000) + parallel /= 4; + size_t max_threads = nthreads == 0 ? 
+ std::thread::hardware_concurrency() : nthreads; + return std::max(size_t(1), std::min(parallel, max_threads)); + } +#endif + }; + +namespace threading { + +#ifdef POCKETFFT_NO_MULTITHREADING + +constexpr inline size_t thread_id() { return 0; } +constexpr inline size_t num_threads() { return 1; } + +template +void thread_map(size_t /* nthreads */, Func f) + { f(); } + +#else + +inline size_t &thread_id() + { + static thread_local size_t thread_id_=0; + return thread_id_; + } +inline size_t &num_threads() + { + static thread_local size_t num_threads_=1; + return num_threads_; + } +static const size_t max_threads = std::max(1u, std::thread::hardware_concurrency()); + +class latch + { + std::atomic num_left_; + std::mutex mut_; + std::condition_variable completed_; + using lock_t = std::unique_lock; + + public: + latch(size_t n): num_left_(n) {} + + void count_down() + { + lock_t lock(mut_); + if (--num_left_) + return; + completed_.notify_all(); + } + + void wait() + { + lock_t lock(mut_); + completed_.wait(lock, [this]{ return is_ready(); }); + } + bool is_ready() { return num_left_ == 0; } + }; + +template class concurrent_queue + { + std::queue q_; + std::mutex mut_; + std::atomic size_; + using lock_t = std::lock_guard; + + public: + + void push(T val) + { + lock_t lock(mut_); + ++size_; + q_.push(std::move(val)); + } + + bool try_pop(T &val) + { + if (size_ == 0) return false; + lock_t lock(mut_); + // Queue might have been emptied while we acquired the lock + if (q_.empty()) return false; + + val = std::move(q_.front()); + --size_; + q_.pop(); + return true; + } + + bool empty() const { return size_==0; } + }; + +// C++ allocator with support for over-aligned types +template struct aligned_allocator + { + using value_type = T; + template + aligned_allocator(const aligned_allocator&) {} + aligned_allocator() = default; + + T *allocate(size_t n) + { + void* mem = aligned_alloc(alignof(T), n*sizeof(T)); + return static_cast(mem); + } + + void deallocate(T *p, size_t /*n*/) + { aligned_dealloc(p); } + }; + +class thread_pool + { + // A reasonable guess, probably close enough for most hardware + static constexpr size_t cache_line_size = 64; + struct alignas(cache_line_size) worker + { + std::thread thread; + std::condition_variable work_ready; + std::mutex mut; + std::atomic_flag busy_flag = ATOMIC_FLAG_INIT; + std::function work; + + void worker_main( + std::atomic &shutdown_flag, + std::atomic &unscheduled_tasks, + concurrent_queue> &overflow_work) + { + using lock_t = std::unique_lock; + bool expect_work = true; + while (!shutdown_flag || expect_work) + { + std::function local_work; + if (expect_work || unscheduled_tasks == 0) + { + lock_t lock(mut); + // Wait until there is work to be executed + work_ready.wait(lock, [&]{ return (work || shutdown_flag); }); + local_work.swap(work); + expect_work = false; + } + + bool marked_busy = false; + if (local_work) + { + marked_busy = true; + local_work(); + } + + if (!overflow_work.empty()) + { + if (!marked_busy && busy_flag.test_and_set()) + { + expect_work = true; + continue; + } + marked_busy = true; + + while (overflow_work.try_pop(local_work)) + { + --unscheduled_tasks; + local_work(); + } + } + + if (marked_busy) busy_flag.clear(); + } + } + }; + + concurrent_queue> overflow_work_; + std::mutex mut_; + std::vector> workers_; + std::atomic shutdown_; + std::atomic unscheduled_tasks_; + using lock_t = std::lock_guard; + + void create_threads() + { + lock_t lock(mut_); + size_t nthreads=workers_.size(); + for (size_t i=0; 
ibusy_flag.clear(); + worker->work = nullptr; + worker->thread = std::thread([worker, this] + { + worker->worker_main(shutdown_, unscheduled_tasks_, overflow_work_); + }); + } + catch (...) + { + shutdown_locked(); + throw; + } + } + } + + void shutdown_locked() + { + shutdown_ = true; + for (auto &worker : workers_) + worker.work_ready.notify_all(); + + for (auto &worker : workers_) + if (worker.thread.joinable()) + worker.thread.join(); + } + + public: + explicit thread_pool(size_t nthreads): + workers_(nthreads) + { create_threads(); } + + thread_pool(): thread_pool(max_threads) {} + + ~thread_pool() { shutdown(); } + + void submit(std::function work) + { + lock_t lock(mut_); + if (shutdown_) + throw std::runtime_error("Work item submitted after shutdown"); + + ++unscheduled_tasks_; + + // First check for any idle workers and wake those + for (auto &worker : workers_) + if (!worker.busy_flag.test_and_set()) + { + --unscheduled_tasks_; + { + lock_t lock(worker.mut); + worker.work = std::move(work); + } + worker.work_ready.notify_one(); + return; + } + + // If no workers were idle, push onto the overflow queue for later + overflow_work_.push(std::move(work)); + } + + void shutdown() + { + lock_t lock(mut_); + shutdown_locked(); + } + + void restart() + { + shutdown_ = false; + create_threads(); + } + }; + +inline thread_pool & get_pool() + { + static thread_pool pool; +#ifdef POCKETFFT_PTHREADS + static std::once_flag f; + std::call_once(f, + []{ + pthread_atfork( + +[]{ get_pool().shutdown(); }, // prepare + +[]{ get_pool().restart(); }, // parent + +[]{ get_pool().restart(); } // child + ); + }); +#endif + + return pool; + } + +/** Map a function f over nthreads */ +template +void thread_map(size_t nthreads, Func f) + { + if (nthreads == 0) + nthreads = max_threads; + + if (nthreads == 1) + { f(); return; } + + auto & pool = get_pool(); + latch counter(nthreads); + std::exception_ptr ex; + std::mutex ex_mut; + for (size_t i=0; i lock(ex_mut); + ex = std::current_exception(); + } + counter.count_down(); + }); + } + counter.wait(); + if (ex) + std::rethrow_exception(ex); + } + +#endif + +} + +// +// complex FFTPACK transforms +// + +template class cfftp + { + private: + struct fctdata + { + size_t fct; + cmplx *tw, *tws; + }; + + size_t length; + arr> mem; + std::vector fact; + + void add_factor(size_t factor) + { fact.push_back({factor, nullptr, nullptr}); } + +template void pass2 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+2*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k(CC(i,0,k)-CC(i,1,k),WA(0,i),CH(i,k,1)); + } + } + } + +#define POCKETFFT_PREP3(idx) \ + T t0 = CC(idx,0,k), t1, t2; \ + PM (t1,t2,CC(idx,1,k),CC(idx,2,k)); \ + CH(idx,k,0)=t0+t1; +#define POCKETFFT_PARTSTEP3a(u1,u2,twr,twi) \ + { \ + T ca=t0+t1*twr; \ + T cb{-t2.i*twi, t2.r*twi}; \ + PM(CH(0,k,u1),CH(0,k,u2),ca,cb) ;\ + } +#define POCKETFFT_PARTSTEP3b(u1,u2,twr,twi) \ + { \ + T ca=t0+t1*twr; \ + T cb{-t2.i*twi, t2.r*twi}; \ + special_mul(ca+cb,WA(u1-1,i),CH(i,k,u1)); \ + special_mul(ca-cb,WA(u2-1,i),CH(i,k,u2)); \ + } +template void pass3 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + constexpr 
T0 tw1r=-0.5, + tw1i= (fwd ? -1: 1) * T0(0.8660254037844386467637231707529362L); + + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+3*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k void pass4 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+4*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k(t4); + PM(CH(0,k,0),CH(0,k,2),t2,t3); + PM(CH(0,k,1),CH(0,k,3),t1,t4); + } + else + for (size_t k=0; k(t4); + PM(CH(0,k,0),CH(0,k,2),t2,t3); + PM(CH(0,k,1),CH(0,k,3),t1,t4); + } + for (size_t i=1; i(t4); + CH(i,k,0) = t2+t3; + special_mul(t1+t4,WA(0,i),CH(i,k,1)); + special_mul(t2-t3,WA(1,i),CH(i,k,2)); + special_mul(t1-t4,WA(2,i),CH(i,k,3)); + } + } + } + +#define POCKETFFT_PREP5(idx) \ + T t0 = CC(idx,0,k), t1, t2, t3, t4; \ + PM (t1,t4,CC(idx,1,k),CC(idx,4,k)); \ + PM (t2,t3,CC(idx,2,k),CC(idx,3,k)); \ + CH(idx,k,0).r=t0.r+t1.r+t2.r; \ + CH(idx,k,0).i=t0.i+t1.i+t2.i; + +#define POCKETFFT_PARTSTEP5a(u1,u2,twar,twbr,twai,twbi) \ + { \ + T ca,cb; \ + ca.r=t0.r+twar*t1.r+twbr*t2.r; \ + ca.i=t0.i+twar*t1.i+twbr*t2.i; \ + cb.i=twai*t4.r twbi*t3.r; \ + cb.r=-(twai*t4.i twbi*t3.i); \ + PM(CH(0,k,u1),CH(0,k,u2),ca,cb); \ + } + +#define POCKETFFT_PARTSTEP5b(u1,u2,twar,twbr,twai,twbi) \ + { \ + T ca,cb,da,db; \ + ca.r=t0.r+twar*t1.r+twbr*t2.r; \ + ca.i=t0.i+twar*t1.i+twbr*t2.i; \ + cb.i=twai*t4.r twbi*t3.r; \ + cb.r=-(twai*t4.i twbi*t3.i); \ + special_mul(ca+cb,WA(u1-1,i),CH(i,k,u1)); \ + special_mul(ca-cb,WA(u2-1,i),CH(i,k,u2)); \ + } +template void pass5 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tw1r= T0(0.3090169943749474241022934171828191L), + tw1i= (fwd ? -1: 1) * T0(0.9510565162951535721164393333793821L), + tw2r= T0(-0.8090169943749474241022934171828191L), + tw2i= (fwd ? -1: 1) * T0(0.5877852522924731291687059546390728L); + + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+5*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k(da,WA(u1-1,i),CH(i,k,u1)); \ + special_mul(db,WA(u2-1,i),CH(i,k,u2)); \ + } + +template void pass7(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tw1r= T0(0.6234898018587335305250048840042398L), + tw1i= (fwd ? -1 : 1) * T0(0.7818314824680298087084445266740578L), + tw2r= T0(-0.2225209339563144042889025644967948L), + tw2i= (fwd ? -1 : 1) * T0(0.9749279121818236070181316829939312L), + tw3r= T0(-0.9009688679024191262361023195074451L), + tw3i= (fwd ? 
-1 : 1) * T0(0.433883739117558120475768332848359L); + + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+7*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k void ROTX45(T &a) const + { + constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L); + if (fwd) + { auto tmp_=a.r; a.r=hsqt2*(a.r+a.i); a.i=hsqt2*(a.i-tmp_); } + else + { auto tmp_=a.r; a.r=hsqt2*(a.r-a.i); a.i=hsqt2*(a.i+tmp_); } + } +template void ROTX135(T &a) const + { + constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L); + if (fwd) + { auto tmp_=a.r; a.r=hsqt2*(a.i-a.r); a.i=hsqt2*(-tmp_-a.i); } + else + { auto tmp_=a.r; a.r=hsqt2*(-a.r-a.i); a.i=hsqt2*(tmp_-a.i); } + } + +template void pass8 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+8*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k(a3); + + ROTX90(a7); + PMINPLACE(a5,a7); + ROTX45(a5); + ROTX135(a7); + + PM(a0,a4,CC(0,0,k),CC(0,4,k)); + PM(a2,a6,CC(0,2,k),CC(0,6,k)); + PM(CH(0,k,0),CH(0,k,4),a0+a2,a1); + PM(CH(0,k,2),CH(0,k,6),a0-a2,a3); + ROTX90(a6); + PM(CH(0,k,1),CH(0,k,5),a4+a6,a5); + PM(CH(0,k,3),CH(0,k,7),a4-a6,a7); + } + else + for (size_t k=0; k(a3); + + ROTX90(a7); + PMINPLACE(a5,a7); + ROTX45(a5); + ROTX135(a7); + + PM(a0,a4,CC(0,0,k),CC(0,4,k)); + PM(a2,a6,CC(0,2,k),CC(0,6,k)); + PM(CH(0,k,0),CH(0,k,4),a0+a2,a1); + PM(CH(0,k,2),CH(0,k,6),a0-a2,a3); + ROTX90(a6); + PM(CH(0,k,1),CH(0,k,5),a4+a6,a5); + PM(CH(0,k,3),CH(0,k,7),a4-a6,a7); + } + for (size_t i=1; i(a7); + PMINPLACE(a1,a3); + ROTX90(a3); + PMINPLACE(a5,a7); + ROTX45(a5); + ROTX135(a7); + PM(a0,a4,CC(i,0,k),CC(i,4,k)); + PM(a2,a6,CC(i,2,k),CC(i,6,k)); + PMINPLACE(a0,a2); + CH(i,k,0) = a0+a1; + special_mul(a0-a1,WA(3,i),CH(i,k,4)); + special_mul(a2+a3,WA(1,i),CH(i,k,2)); + special_mul(a2-a3,WA(5,i),CH(i,k,6)); + ROTX90(a6); + PMINPLACE(a4,a6); + special_mul(a4+a5,WA(0,i),CH(i,k,1)); + special_mul(a4-a5,WA(4,i),CH(i,k,5)); + special_mul(a6+a7,WA(2,i),CH(i,k,3)); + special_mul(a6-a7,WA(6,i),CH(i,k,7)); + } + } + } + + +#define POCKETFFT_PREP11(idx) \ + T t1 = CC(idx,0,k), t2, t3, t4, t5, t6, t7, t8, t9, t10, t11; \ + PM (t2,t11,CC(idx,1,k),CC(idx,10,k)); \ + PM (t3,t10,CC(idx,2,k),CC(idx, 9,k)); \ + PM (t4,t9 ,CC(idx,3,k),CC(idx, 8,k)); \ + PM (t5,t8 ,CC(idx,4,k),CC(idx, 7,k)); \ + PM (t6,t7 ,CC(idx,5,k),CC(idx, 6,k)); \ + CH(idx,k,0).r=t1.r+t2.r+t3.r+t4.r+t5.r+t6.r; \ + CH(idx,k,0).i=t1.i+t2.i+t3.i+t4.i+t5.i+t6.i; + +#define POCKETFFT_PARTSTEP11a0(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5,out1,out2) \ + { \ + T ca = t1 + t2*x1 + t3*x2 + t4*x3 + t5*x4 +t6*x5, \ + cb; \ + cb.i=y1*t11.r y2*t10.r y3*t9.r y4*t8.r y5*t7.r; \ + cb.r=-(y1*t11.i y2*t10.i y3*t9.i y4*t8.i y5*t7.i ); \ + PM(out1,out2,ca,cb); \ + } +#define POCKETFFT_PARTSTEP11a(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5) \ + POCKETFFT_PARTSTEP11a0(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5,CH(0,k,u1),CH(0,k,u2)) +#define POCKETFFT_PARTSTEP11(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5) \ + { \ + T da,db; \ + POCKETFFT_PARTSTEP11a0(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5,da,db) \ + special_mul(da,WA(u1-1,i),CH(i,k,u1)); \ + 
special_mul(db,WA(u2-1,i),CH(i,k,u2)); \ + } + +template void pass11 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tw1r= T0(0.8412535328311811688618116489193677L), + tw1i= (fwd ? -1 : 1) * T0(0.5406408174555975821076359543186917L), + tw2r= T0(0.4154150130018864255292741492296232L), + tw2i= (fwd ? -1 : 1) * T0(0.9096319953545183714117153830790285L), + tw3r= T0(-0.1423148382732851404437926686163697L), + tw3i= (fwd ? -1 : 1) * T0(0.9898214418809327323760920377767188L), + tw4r= T0(-0.6548607339452850640569250724662936L), + tw4i= (fwd ? -1 : 1) * T0(0.7557495743542582837740358439723444L), + tw5r= T0(-0.9594929736144973898903680570663277L), + tw5i= (fwd ? -1 : 1) * T0(0.2817325568414296977114179153466169L); + + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+11*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k void passg (size_t ido, size_t ip, + size_t l1, T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa, + const cmplx * POCKETFFT_RESTRICT csarr) const + { + const size_t cdim=ip; + size_t ipph = (ip+1)/2; + size_t idl1 = ido*l1; + + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+cdim*c)]; }; + auto CX = [cc, ido, l1](size_t a, size_t b, size_t c) -> T& + { return cc[a+ido*(b+l1*c)]; }; + auto CX2 = [cc, idl1](size_t a, size_t b) -> T& + { return cc[a+idl1*b]; }; + auto CH2 = [ch, idl1](size_t a, size_t b) -> const T& + { return ch[a+idl1*b]; }; + + arr> wal(ip); + wal[0] = cmplx(1., 0.); + for (size_t i=1; i(csarr[i].r,fwd ? -csarr[i].i : csarr[i].i); + + for (size_t k=0; kip) iwal-=ip; + cmplx xwal=wal[iwal]; + iwal+=l; if (iwal>ip) iwal-=ip; + cmplx xwal2=wal[iwal]; + for (size_t ik=0; ikip) iwal-=ip; + cmplx xwal=wal[iwal]; + for (size_t ik=0; ik(x1,wa[idij],CX(i,k,j)); + idij=(jc-1)*(ido-1)+i-1; + special_mul(x2,wa[idij],CX(i,k,jc)); + } + } + } + } + +template void pass_all(T c[], T0 fct) const + { + if (length==1) { c[0]*=fct; return; } + size_t l1=1; + arr ch(length); + T *p1=c, *p2=ch.data(); + + for(size_t k1=0; k1 (ido, l1, p1, p2, fact[k1].tw); + else if(ip==8) + pass8(ido, l1, p1, p2, fact[k1].tw); + else if(ip==2) + pass2(ido, l1, p1, p2, fact[k1].tw); + else if(ip==3) + pass3 (ido, l1, p1, p2, fact[k1].tw); + else if(ip==5) + pass5 (ido, l1, p1, p2, fact[k1].tw); + else if(ip==7) + pass7 (ido, l1, p1, p2, fact[k1].tw); + else if(ip==11) + pass11 (ido, l1, p1, p2, fact[k1].tw); + else + { + passg(ido, ip, l1, p1, p2, fact[k1].tw, fact[k1].tws); + std::swap(p1,p2); + } + std::swap(p1,p2); + l1=l2; + } + if (p1!=c) + { + if (fct!=1.) + for (size_t i=0; i void exec(T c[], T0 fct, bool fwd) const + { fwd ? 
pass_all(c, fct) : pass_all(c, fct); } + + private: + POCKETFFT_NOINLINE void factorize() + { + size_t len=length; + while ((len&7)==0) + { add_factor(8); len>>=3; } + while ((len&3)==0) + { add_factor(4); len>>=2; } + if ((len&1)==0) + { + len>>=1; + // factor 2 should be at the front of the factor list + add_factor(2); + std::swap(fact[0].fct, fact.back().fct); + } + for (size_t divisor=3; divisor*divisor<=len; divisor+=2) + while ((len%divisor)==0) + { + add_factor(divisor); + len/=divisor; + } + if (len>1) add_factor(len); + } + + size_t twsize() const + { + size_t twsize=0, l1=1; + for (size_t k=0; k11) + twsize+=ip; + l1*=ip; + } + return twsize; + } + + void comp_twiddle() + { + sincos_2pibyn twiddle(length); + size_t l1=1; + size_t memofs=0; + for (size_t k=0; k11) + { + fact[k].tws=mem.data()+memofs; + memofs+=ip; + for (size_t j=0; j class rfftp + { + private: + struct fctdata + { + size_t fct; + T0 *tw, *tws; + }; + + size_t length; + arr mem; + std::vector fact; + + void add_factor(size_t factor) + { fact.push_back({factor, nullptr, nullptr}); } + +/* (a+ib) = conj(c+id) * (e+if) */ +template inline void MULPM + (T1 &a, T1 &b, T2 c, T2 d, T3 e, T3 f) const + { a=c*e+d*f; b=c*f-d*e; } + +template void radf2 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+l1*c)]; }; + auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+2*c)]; }; + + for (size_t k=0; k void radf3(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 taur=-0.5, taui=T0(0.8660254037844386467637231707529362L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+l1*c)]; }; + auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+3*c)]; }; + + for (size_t k=0; k void radf4(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+l1*c)]; }; + auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+4*c)]; }; + + for (size_t k=0; k void radf5(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tr11= T0(0.3090169943749474241022934171828191L), + ti11= T0(0.9510565162951535721164393333793821L), + tr12= T0(-0.8090169943749474241022934171828191L), + ti12= T0(0.5877852522924731291687059546390728L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+l1*c)]; }; + auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+5*c)]; }; + + for (size_t k=0; k void radfg(size_t ido, size_t ip, size_t l1, + T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa, const T0 * POCKETFFT_RESTRICT csarr) const + { + const size_t cdim=ip; + size_t ipph=(ip+1)/2; + size_t idl1 = ido*l1; + + auto 
CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> T& + { return cc[a+ido*(b+cdim*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return ch[a+ido*(b+l1*c)]; }; + auto C1 = [cc,ido,l1] (size_t a, size_t b, size_t c) -> T& + { return cc[a+ido*(b+l1*c)]; }; + auto C2 = [cc,idl1] (size_t a, size_t b) -> T& + { return cc[a+idl1*b]; }; + auto CH2 = [ch,idl1] (size_t a, size_t b) -> T& + { return ch[a+idl1*b]; }; + + if (ido>1) + { + for (size_t j=1, jc=ip-1; j=ip) iang-=ip; + T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; + iang+=l; if (iang>=ip) iang-=ip; + T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; + iang+=l; if (iang>=ip) iang-=ip; + T0 ar3=csarr[2*iang], ai3=csarr[2*iang+1]; + iang+=l; if (iang>=ip) iang-=ip; + T0 ar4=csarr[2*iang], ai4=csarr[2*iang+1]; + for (size_t ik=0; ik=ip) iang-=ip; + T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; + iang+=l; if (iang>=ip) iang-=ip; + T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; + for (size_t ik=0; ik=ip) iang-=ip; + T0 ar=csarr[2*iang], ai=csarr[2*iang+1]; + for (size_t ik=0; ik void radb2(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+2*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + + for (size_t k=0; k void radb3(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 taur=-0.5, taui=T0(0.8660254037844386467637231707529362L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+3*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + + for (size_t k=0; k void radb4(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+4*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + + for (size_t k=0; k void radb5(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tr11= T0(0.3090169943749474241022934171828191L), + ti11= T0(0.9510565162951535721164393333793821L), + tr12= T0(-0.8090169943749474241022934171828191L), + ti12= T0(0.5877852522924731291687059546390728L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+5*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + + for (size_t k=0; k void radbg(size_t ido, size_t ip, size_t l1, + T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa, const T0 * POCKETFFT_RESTRICT csarr) const + { + const size_t cdim=ip; + size_t ipph=(ip+1)/ 2; + size_t idl1 = ido*l1; + + auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+cdim*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return 
ch[a+ido*(b+l1*c)]; }; + auto C1 = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+l1*c)]; }; + auto C2 = [cc,idl1](size_t a, size_t b) -> T& + { return cc[a+idl1*b]; }; + auto CH2 = [ch,idl1](size_t a, size_t b) -> T& + { return ch[a+idl1*b]; }; + + for (size_t k=0; kip) iang-=ip; + T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; + iang+=l; if(iang>ip) iang-=ip; + T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; + iang+=l; if(iang>ip) iang-=ip; + T0 ar3=csarr[2*iang], ai3=csarr[2*iang+1]; + iang+=l; if(iang>ip) iang-=ip; + T0 ar4=csarr[2*iang], ai4=csarr[2*iang+1]; + for (size_t ik=0; ikip) iang-=ip; + T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; + iang+=l; if(iang>ip) iang-=ip; + T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; + for (size_t ik=0; ikip) iang-=ip; + T0 war=csarr[2*iang], wai=csarr[2*iang+1]; + for (size_t ik=0; ik void copy_and_norm(T *c, T *p1, T0 fct) const + { + if (p1!=c) + { + if (fct!=1.) + for (size_t i=0; i void exec(T c[], T0 fct, bool r2hc) const + { + if (length==1) { c[0]*=fct; return; } + size_t nf=fact.size(); + arr ch(length); + T *p1=c, *p2=ch.data(); + + if (r2hc) + for(size_t k1=0, l1=length; k1>=2; } + if ((len%2)==0) + { + len>>=1; + // factor 2 should be at the front of the factor list + add_factor(2); + std::swap(fact[0].fct, fact.back().fct); + } + for (size_t divisor=3; divisor*divisor<=len; divisor+=2) + while ((len%divisor)==0) + { + add_factor(divisor); + len/=divisor; + } + if (len>1) add_factor(len); + } + + size_t twsize() const + { + size_t twsz=0, l1=1; + for (size_t k=0; k5) twsz+=2*ip; + l1*=ip; + } + return twsz; + } + + void comp_twiddle() + { + sincos_2pibyn twid(length); + size_t l1=1; + T0 *ptr=mem.data(); + for (size_t k=0; k5) // special factors required by *g functions + { + fact[k].tws=ptr; ptr+=2*ip; + fact[k].tws[0] = 1.; + fact[k].tws[1] = 0.; + for (size_t i=2, ic=2*ip-2; i<=ic; i+=2, ic-=2) + { + fact[k].tws[i ] = twid[i/2*(length/ip)].r; + fact[k].tws[i+1] = twid[i/2*(length/ip)].i; + fact[k].tws[ic] = twid[i/2*(length/ip)].r; + fact[k].tws[ic+1] = -twid[i/2*(length/ip)].i; + } + } + l1*=ip; + } + } + + public: + POCKETFFT_NOINLINE rfftp(size_t length_) + : length(length_) + { + if (length==0) throw std::runtime_error("zero-length FFT requested"); + if (length==1) return; + factorize(); + mem.resize(twsize()); + comp_twiddle(); + } +}; + +// +// complex Bluestein transforms +// + +template class fftblue + { + private: + size_t n, n2; + cfftp plan; + arr> mem; + cmplx *bk, *bkf; + + template void fft(cmplx c[], T0 fct) const + { + arr> akf(n2); + + /* initialize a_k and FFT it */ + for (size_t m=0; m(c[m],bk[m],akf[m]); + auto zero = akf[0]*T0(0); + for (size_t m=n; m(bkf[0]); + for (size_t m=1; m<(n2+1)/2; ++m) + { + akf[m] = akf[m].template special_mul(bkf[m]); + akf[n2-m] = akf[n2-m].template special_mul(bkf[m]); + } + if ((n2&1)==0) + akf[n2/2] = akf[n2/2].template special_mul(bkf[n2/2]); + + /* inverse FFT */ + plan.exec (akf.data(),1.,false); + + /* multiply by b_k */ + for (size_t m=0; m(bk[m])*fct; + } + + public: + POCKETFFT_NOINLINE fftblue(size_t length) + : n(length), n2(util::good_size_cmplx(n*2-1)), plan(n2), mem(n+n2/2+1), + bk(mem.data()), bkf(mem.data()+n) + { + /* initialize b_k */ + sincos_2pibyn tmp(2*n); + bk[0].Set(1, 0); + + size_t coeff=0; + for (size_t m=1; m=2*n) coeff-=2*n; + bk[m] = tmp[coeff]; + } + + /* initialize the zero-padded, Fourier transformed b_k. Add normalisation. 
*/ + arr> tbkf(n2); + T0 xn2 = T0(1)/T0(n2); + tbkf[0] = bk[0]*xn2; + for (size_t m=1; m void exec(cmplx c[], T0 fct, bool fwd) const + { fwd ? fft(c,fct) : fft(c,fct); } + + template void exec_r(T c[], T0 fct, bool fwd) + { + arr> tmp(n); + if (fwd) + { + auto zero = T0(0)*c[0]; + for (size_t m=0; m(tmp.data(),fct); + c[0] = tmp[0].r; + std::copy_n (&tmp[1].r, n-1, &c[1]); + } + else + { + tmp[0].Set(c[0],c[0]*0); + std::copy_n (c+1, n-1, &tmp[1].r); + if ((n&1)==0) tmp[n/2].i=T0(0)*c[0]; + for (size_t m=1; 2*m(tmp.data(),fct); + for (size_t m=0; m class pocketfft_c + { + private: + std::unique_ptr> packplan; + std::unique_ptr> blueplan; + size_t len; + + public: + POCKETFFT_NOINLINE pocketfft_c(size_t length) + : len(length) + { + if (length==0) throw std::runtime_error("zero-length FFT requested"); + size_t tmp = (length<50) ? 0 : util::largest_prime_factor(length); + if (tmp*tmp <= length) + { + packplan=std::unique_ptr>(new cfftp(length)); + return; + } + double comp1 = util::cost_guess(length); + double comp2 = 2*util::cost_guess(util::good_size_cmplx(2*length-1)); + comp2*=1.5; /* fudge factor that appears to give good overall performance */ + if (comp2>(new fftblue(length)); + else + packplan=std::unique_ptr>(new cfftp(length)); + } + + template POCKETFFT_NOINLINE void exec(cmplx c[], T0 fct, bool fwd) const + { packplan ? packplan->exec(c,fct,fwd) : blueplan->exec(c,fct,fwd); } + + size_t length() const { return len; } + }; + +// +// flexible (FFTPACK/Bluestein) real-valued 1D transform +// + +template class pocketfft_r + { + private: + std::unique_ptr> packplan; + std::unique_ptr> blueplan; + size_t len; + + public: + POCKETFFT_NOINLINE pocketfft_r(size_t length) + : len(length) + { + if (length==0) throw std::runtime_error("zero-length FFT requested"); + size_t tmp = (length<50) ? 0 : util::largest_prime_factor(length); + if (tmp*tmp <= length) + { + packplan=std::unique_ptr>(new rfftp(length)); + return; + } + double comp1 = 0.5*util::cost_guess(length); + double comp2 = 2*util::cost_guess(util::good_size_cmplx(2*length-1)); + comp2*=1.5; /* fudge factor that appears to give good overall performance */ + if (comp2>(new fftblue(length)); + else + packplan=std::unique_ptr>(new rfftp(length)); + } + + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool fwd) const + { packplan ? 
packplan->exec(c,fct,fwd) : blueplan->exec_r(c,fct,fwd); } + + size_t length() const { return len; } + }; + + +// +// sine/cosine transforms +// + +template class T_dct1 + { + private: + pocketfft_r fftplan; + + public: + POCKETFFT_NOINLINE T_dct1(size_t length) + : fftplan(2*(length-1)) {} + + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, + int /*type*/, bool /*cosine*/) const + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + size_t N=fftplan.length(), n=N/2+1; + if (ortho) + { c[0]*=sqrt2; c[n-1]*=sqrt2; } + arr tmp(N); + tmp[0] = c[0]; + for (size_t i=1; i class T_dst1 + { + private: + pocketfft_r fftplan; + + public: + POCKETFFT_NOINLINE T_dst1(size_t length) + : fftplan(2*(length+1)) {} + + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, + bool /*ortho*/, int /*type*/, bool /*cosine*/) const + { + size_t N=fftplan.length(), n=N/2-1; + arr tmp(N); + tmp[0] = tmp[n+1] = c[0]*0; + for (size_t i=0; i class T_dcst23 + { + private: + pocketfft_r fftplan; + std::vector twiddle; + + public: + POCKETFFT_NOINLINE T_dcst23(size_t length) + : fftplan(length), twiddle(length) + { + sincos_2pibyn tw(4*length); + for (size_t i=0; i POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, + int type, bool cosine) const + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + size_t N=length(); + size_t NS2 = (N+1)/2; + if (type==2) + { + if (!cosine) + for (size_t k=1; k class T_dcst4 + { + private: + size_t N; + std::unique_ptr> fft; + std::unique_ptr> rfft; + arr> C2; + + public: + POCKETFFT_NOINLINE T_dcst4(size_t length) + : N(length), + fft((N&1) ? nullptr : new pocketfft_c(N/2)), + rfft((N&1)? new pocketfft_r(N) : nullptr), + C2((N&1) ? 0 : N/2) + { + if ((N&1)==0) + { + sincos_2pibyn tw(16*N); + for (size_t i=0; i POCKETFFT_NOINLINE void exec(T c[], T0 fct, + bool /*ortho*/, int /*type*/, bool cosine) const + { + size_t n2 = N/2; + if (!cosine) + for (size_t k=0, kc=N-1; k y(N); + { + size_t i=0, m=n2; + for (; mexec(y.data(), fct, true); + { + auto SGN = [](size_t i) + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + return (i&2) ? 
-sqrt2 : sqrt2; + }; + c[n2] = y[0]*SGN(n2+1); + size_t i=0, i1=1, k=1; + for (; k> y(n2); + for(size_t i=0; iexec(y.data(), fct, true); + for(size_t i=0, ic=n2-1; i std::shared_ptr get_plan(size_t length) + { +#if POCKETFFT_CACHE_SIZE==0 + return std::make_shared(length); +#else + constexpr size_t nmax=POCKETFFT_CACHE_SIZE; + static std::array, nmax> cache; + static std::array last_access{{0}}; + static size_t access_counter = 0; + static std::mutex mut; + + auto find_in_cache = [&]() -> std::shared_ptr + { + for (size_t i=0; ilength()==length)) + { + // no need to update if this is already the most recent entry + if (last_access[i]!=access_counter) + { + last_access[i] = ++access_counter; + // Guard against overflow + if (access_counter == 0) + last_access.fill(0); + } + return cache[i]; + } + + return nullptr; + }; + + { + std::lock_guard lock(mut); + auto p = find_in_cache(); + if (p) return p; + } + auto plan = std::make_shared(length); + { + std::lock_guard lock(mut); + auto p = find_in_cache(); + if (p) return p; + + size_t lru = 0; + for (size_t i=1; i class cndarr: public arr_info + { + protected: + const char *d; + + public: + cndarr(const void *data_, const shape_t &shape_, const stride_t &stride_) + : arr_info(shape_, stride_), + d(reinterpret_cast(data_)) {} + const T &operator[](ptrdiff_t ofs) const + { return *reinterpret_cast(d+ofs); } + }; + +template class ndarr: public cndarr + { + public: + ndarr(void *data_, const shape_t &shape_, const stride_t &stride_) + : cndarr::cndarr(const_cast(data_), shape_, stride_) + {} + T &operator[](ptrdiff_t ofs) + { return *reinterpret_cast(const_cast(cndarr::d+ofs)); } + }; + +template class multi_iter + { + private: + shape_t pos; + const arr_info &iarr, &oarr; + ptrdiff_t p_ii, p_i[N], str_i, p_oi, p_o[N], str_o; + size_t idim, rem; + + void advance_i() + { + for (int i_=int(pos.size())-1; i_>=0; --i_) + { + auto i = size_t(i_); + if (i==idim) continue; + p_ii += iarr.stride(i); + p_oi += oarr.stride(i); + if (++pos[i] < iarr.shape(i)) + return; + pos[i] = 0; + p_ii -= ptrdiff_t(iarr.shape(i))*iarr.stride(i); + p_oi -= ptrdiff_t(oarr.shape(i))*oarr.stride(i); + } + } + + public: + multi_iter(const arr_info &iarr_, const arr_info &oarr_, size_t idim_) + : pos(iarr_.ndim(), 0), iarr(iarr_), oarr(oarr_), p_ii(0), + str_i(iarr.stride(idim_)), p_oi(0), str_o(oarr.stride(idim_)), + idim(idim_), rem(iarr.size()/iarr.shape(idim)) + { + auto nshares = threading::num_threads(); + if (nshares==1) return; + if (nshares==0) throw std::runtime_error("can't run with zero threads"); + auto myshare = threading::thread_id(); + if (myshare>=nshares) throw std::runtime_error("impossible share requested"); + size_t nbase = rem/nshares; + size_t additional = rem%nshares; + size_t lo = myshare*nbase + ((myshare=0; --i_) + { + auto i = size_t(i_); + p += arr.stride(i); + if (++pos[i] < arr.shape(i)) + return; + pos[i] = 0; + p -= ptrdiff_t(arr.shape(i))*arr.stride(i); + } + } + ptrdiff_t ofs() const { return p; } + size_t remaining() const { return rem; } + }; + +class rev_iter + { + private: + shape_t pos; + const arr_info &arr; + std::vector rev_axis; + std::vector rev_jump; + size_t last_axis, last_size; + shape_t shp; + ptrdiff_t p, rp; + size_t rem; + + public: + rev_iter(const arr_info &arr_, const shape_t &axes) + : pos(arr_.ndim(), 0), arr(arr_), rev_axis(arr_.ndim(), 0), + rev_jump(arr_.ndim(), 1), p(0), rp(0) + { + for (auto ax: axes) + rev_axis[ax]=1; + last_axis = axes.back(); + last_size = arr.shape(last_axis)/2 + 1; + shp = arr.shape(); + 
shp[last_axis] = last_size; + rem=1; + for (auto i: shp) + rem *= i; + } + void advance() + { + --rem; + for (int i_=int(pos.size())-1; i_>=0; --i_) + { + auto i = size_t(i_); + p += arr.stride(i); + if (!rev_axis[i]) + rp += arr.stride(i); + else + { + rp -= arr.stride(i); + if (rev_jump[i]) + { + rp += ptrdiff_t(arr.shape(i))*arr.stride(i); + rev_jump[i] = 0; + } + } + if (++pos[i] < shp[i]) + return; + pos[i] = 0; + p -= ptrdiff_t(shp[i])*arr.stride(i); + if (rev_axis[i]) + { + rp -= ptrdiff_t(arr.shape(i)-shp[i])*arr.stride(i); + rev_jump[i] = 1; + } + else + rp -= ptrdiff_t(shp[i])*arr.stride(i); + } + } + ptrdiff_t ofs() const { return p; } + ptrdiff_t rev_ofs() const { return rp; } + size_t remaining() const { return rem; } + }; + +template struct VTYPE {}; +template using vtype_t = typename VTYPE::type; + +#ifndef POCKETFFT_NO_VECTORS +template<> struct VTYPE + { + using type = float __attribute__ ((vector_size (VLEN::val*sizeof(float)))); + }; +template<> struct VTYPE + { + using type = double __attribute__ ((vector_size (VLEN::val*sizeof(double)))); + }; +template<> struct VTYPE + { + using type = long double __attribute__ ((vector_size (VLEN::val*sizeof(long double)))); + }; +#endif + +template arr alloc_tmp(const shape_t &shape, + size_t axsize, size_t elemsize) + { + auto othersize = util::prod(shape)/axsize; + auto tmpsize = axsize*((othersize>=VLEN::val) ? VLEN::val : 1); + return arr(tmpsize*elemsize); + } +template arr alloc_tmp(const shape_t &shape, + const shape_t &axes, size_t elemsize) + { + size_t fullsize=util::prod(shape); + size_t tmpsize=0; + for (size_t i=0; i=VLEN::val) ? VLEN::val : 1); + if (sz>tmpsize) tmpsize=sz; + } + return arr(tmpsize*elemsize); + } + +template void copy_input(const multi_iter &it, + const cndarr> &src, cmplx> *POCKETFFT_RESTRICT dst) + { + for (size_t i=0; i void copy_input(const multi_iter &it, + const cndarr &src, vtype_t *POCKETFFT_RESTRICT dst) + { + for (size_t i=0; i void copy_input(const multi_iter &it, + const cndarr &src, T *POCKETFFT_RESTRICT dst) + { + if (dst == &src[it.iofs(0)]) return; // in-place + for (size_t i=0; i void copy_output(const multi_iter &it, + const cmplx> *POCKETFFT_RESTRICT src, ndarr> &dst) + { + for (size_t i=0; i void copy_output(const multi_iter &it, + const vtype_t *POCKETFFT_RESTRICT src, ndarr &dst) + { + for (size_t i=0; i void copy_output(const multi_iter &it, + const T *POCKETFFT_RESTRICT src, ndarr &dst) + { + if (src == &dst[it.oofs(0)]) return; // in-place + for (size_t i=0; i struct add_vec { using type = vtype_t; }; +template struct add_vec> + { using type = cmplx>; }; +template using add_vec_t = typename add_vec::type; + +template +POCKETFFT_NOINLINE void general_nd(const cndarr &in, ndarr &out, + const shape_t &axes, T0 fct, size_t nthreads, const Exec & exec, + const bool allow_inplace=true) + { + std::shared_ptr plan; + + for (size_t iax=0; iaxlength())) + plan = get_plan(len); + + threading::thread_map( + util::thread_count(nthreads, in.shape(), axes[iax], VLEN::val), + [&] { + constexpr auto vlen = VLEN::val; + auto storage = alloc_tmp(in.shape(), len, sizeof(T)); + const auto &tin(iax==0? in : out); + multi_iter it(tin, out, axes[iax]); +#ifndef POCKETFFT_NO_VECTORS + if (vlen>1) + while (it.remaining()>=vlen) + { + it.advance(vlen); + auto tdatav = reinterpret_cast *>(storage.data()); + exec(it, tin, out, tdatav, *plan, fct); + } +#endif + while (it.remaining()>0) + { + it.advance(1); + auto buf = allow_inplace && it.stride_out() == sizeof(T) ? 
+ &out[it.oofs(0)] : reinterpret_cast(storage.data()); + exec(it, tin, out, buf, *plan, fct); + } + }); // end of parallel region + fct = T0(1); // factor has been applied, use 1 for remaining axes + } + } + +struct ExecC2C + { + bool forward; + + template void operator () ( + const multi_iter &it, const cndarr> &in, + ndarr> &out, T * buf, const pocketfft_c &plan, T0 fct) const + { + copy_input(it, in, buf); + plan.exec(buf, fct, forward); + copy_output(it, buf, out); + } + }; + +template void copy_hartley(const multi_iter &it, + const vtype_t *POCKETFFT_RESTRICT src, ndarr &dst) + { + for (size_t j=0; j void copy_hartley(const multi_iter &it, + const T *POCKETFFT_RESTRICT src, ndarr &dst) + { + dst[it.oofs(0)] = src[0]; + size_t i=1, i1=1, i2=it.length_out()-1; + for (i=1; i void operator () ( + const multi_iter &it, const cndarr &in, ndarr &out, + T * buf, const pocketfft_r &plan, T0 fct) const + { + copy_input(it, in, buf); + plan.exec(buf, fct, true); + copy_hartley(it, buf, out); + } + }; + +struct ExecDcst + { + bool ortho; + int type; + bool cosine; + + template + void operator () (const multi_iter &it, const cndarr &in, + ndarr &out, T * buf, const Tplan &plan, T0 fct) const + { + copy_input(it, in, buf); + plan.exec(buf, fct, ortho, type, cosine); + copy_output(it, buf, out); + } + }; + +template POCKETFFT_NOINLINE void general_r2c( + const cndarr &in, ndarr> &out, size_t axis, bool forward, T fct, + size_t nthreads) + { + auto plan = get_plan>(in.shape(axis)); + size_t len=in.shape(axis); + threading::thread_map( + util::thread_count(nthreads, in.shape(), axis, VLEN::val), + [&] { + constexpr auto vlen = VLEN::val; + auto storage = alloc_tmp(in.shape(), len, sizeof(T)); + multi_iter it(in, out, axis); +#ifndef POCKETFFT_NO_VECTORS + if (vlen>1) + while (it.remaining()>=vlen) + { + it.advance(vlen); + auto tdatav = reinterpret_cast *>(storage.data()); + copy_input(it, in, tdatav); + plan->exec(tdatav, fct, true); + for (size_t j=0; j0) + { + it.advance(1); + auto tdata = reinterpret_cast(storage.data()); + copy_input(it, in, tdata); + plan->exec(tdata, fct, true); + out[it.oofs(0)].Set(tdata[0]); + size_t i=1, ii=1; + if (forward) + for (; i POCKETFFT_NOINLINE void general_c2r( + const cndarr> &in, ndarr &out, size_t axis, bool forward, T fct, + size_t nthreads) + { + auto plan = get_plan>(out.shape(axis)); + size_t len=out.shape(axis); + threading::thread_map( + util::thread_count(nthreads, in.shape(), axis, VLEN::val), + [&] { + constexpr auto vlen = VLEN::val; + auto storage = alloc_tmp(out.shape(), len, sizeof(T)); + multi_iter it(in, out, axis); +#ifndef POCKETFFT_NO_VECTORS + if (vlen>1) + while (it.remaining()>=vlen) + { + it.advance(vlen); + auto tdatav = reinterpret_cast *>(storage.data()); + for (size_t j=0; jexec(tdatav, fct, false); + copy_output(it, tdatav, out); + } +#endif + while (it.remaining()>0) + { + it.advance(1); + auto tdata = reinterpret_cast(storage.data()); + tdata[0]=in[it.iofs(0)].r; + { + size_t i=1, ii=1; + if (forward) + for (; iexec(tdata, fct, false); + copy_output(it, tdata, out); + } + }); // end of parallel region + } + +struct ExecR2R + { + bool r2h, forward; + + template void operator () ( + const multi_iter &it, const cndarr &in, ndarr &out, T * buf, + const pocketfft_r &plan, T0 fct) const + { + copy_input(it, in, buf); + if ((!r2h) && forward) + for (size_t i=2; i void c2c(const shape_t &shape, const stride_t &stride_in, + const stride_t &stride_out, const shape_t &axes, bool forward, + const std::complex *data_in, std::complex *data_out, 
T fct, + size_t nthreads=1) + { + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr> ain(data_in, shape, stride_in); + ndarr> aout(data_out, shape, stride_out); + general_nd>(ain, aout, axes, fct, nthreads, ExecC2C{forward}); + } + +template void dct(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1) + { + if ((type<1) || (type>4)) throw std::invalid_argument("invalid DCT type"); + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr ain(data_in, shape, stride_in); + ndarr aout(data_out, shape, stride_out); + const ExecDcst exec{ortho, type, true}; + if (type==1) + general_nd>(ain, aout, axes, fct, nthreads, exec); + else if (type==4) + general_nd>(ain, aout, axes, fct, nthreads, exec); + else + general_nd>(ain, aout, axes, fct, nthreads, exec); + } + +template void dst(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1) + { + if ((type<1) || (type>4)) throw std::invalid_argument("invalid DST type"); + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr ain(data_in, shape, stride_in); + ndarr aout(data_out, shape, stride_out); + const ExecDcst exec{ortho, type, false}; + if (type==1) + general_nd>(ain, aout, axes, fct, nthreads, exec); + else if (type==4) + general_nd>(ain, aout, axes, fct, nthreads, exec); + else + general_nd>(ain, aout, axes, fct, nthreads, exec); + } + +template void r2c(const shape_t &shape_in, + const stride_t &stride_in, const stride_t &stride_out, size_t axis, + bool forward, const T *data_in, std::complex *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape_in)==0) return; + util::sanity_check(shape_in, stride_in, stride_out, false, axis); + cndarr ain(data_in, shape_in, stride_in); + shape_t shape_out(shape_in); + shape_out[axis] = shape_in[axis]/2 + 1; + ndarr> aout(data_out, shape_out, stride_out); + general_r2c(ain, aout, axis, forward, fct, nthreads); + } + +template void r2c(const shape_t &shape_in, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + bool forward, const T *data_in, std::complex *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape_in)==0) return; + util::sanity_check(shape_in, stride_in, stride_out, false, axes); + r2c(shape_in, stride_in, stride_out, axes.back(), forward, data_in, data_out, + fct, nthreads); + if (axes.size()==1) return; + + shape_t shape_out(shape_in); + shape_out[axes.back()] = shape_in[axes.back()]/2 + 1; + auto newaxes = shape_t{axes.begin(), --axes.end()}; + c2c(shape_out, stride_out, stride_out, newaxes, forward, data_out, data_out, + T(1), nthreads); + } + +template void c2r(const shape_t &shape_out, + const stride_t &stride_in, const stride_t &stride_out, size_t axis, + bool forward, const std::complex *data_in, T *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape_out)==0) return; + util::sanity_check(shape_out, stride_in, stride_out, false, axis); + shape_t shape_in(shape_out); + shape_in[axis] = shape_out[axis]/2 + 1; + cndarr> ain(data_in, shape_in, stride_in); + ndarr aout(data_out, shape_out, stride_out); + general_c2r(ain, aout, axis, forward, fct, nthreads); + } + +template void c2r(const 
shape_t &shape_out, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + bool forward, const std::complex *data_in, T *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape_out)==0) return; + if (axes.size()==1) + return c2r(shape_out, stride_in, stride_out, axes[0], forward, + data_in, data_out, fct, nthreads); + util::sanity_check(shape_out, stride_in, stride_out, false, axes); + auto shape_in = shape_out; + shape_in[axes.back()] = shape_out[axes.back()]/2 + 1; + auto nval = util::prod(shape_in); + stride_t stride_inter(shape_in.size()); + stride_inter.back() = sizeof(cmplx); + for (int i=int(shape_in.size())-2; i>=0; --i) + stride_inter[size_t(i)] = + stride_inter[size_t(i+1)]*ptrdiff_t(shape_in[size_t(i+1)]); + arr> tmp(nval); + auto newaxes = shape_t{axes.begin(), --axes.end()}; + c2c(shape_in, stride_in, stride_inter, newaxes, forward, data_in, tmp.data(), + T(1), nthreads); + c2r(shape_out, stride_inter, stride_out, axes.back(), forward, + tmp.data(), data_out, fct, nthreads); + } + +template void r2r_fftpack(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + bool real2hermitian, bool forward, const T *data_in, T *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr ain(data_in, shape, stride_in); + ndarr aout(data_out, shape, stride_out); + general_nd>(ain, aout, axes, fct, nthreads, + ExecR2R{real2hermitian, forward}); + } + +template void r2r_separable_hartley(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + const T *data_in, T *data_out, T fct, size_t nthreads=1) + { + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr ain(data_in, shape, stride_in); + ndarr aout(data_out, shape, stride_out); + general_nd>(ain, aout, axes, fct, nthreads, ExecHartley{}, + false); + } + +template void r2r_genuine_hartley(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + const T *data_in, T *data_out, T fct, size_t nthreads=1) + { + if (util::prod(shape)==0) return; + if (axes.size()==1) + return r2r_separable_hartley(shape, stride_in, stride_out, axes, data_in, + data_out, fct, nthreads); + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + shape_t tshp(shape); + tshp[axes.back()] = tshp[axes.back()]/2+1; + arr> tdata(util::prod(tshp)); + stride_t tstride(shape.size()); + tstride.back()=sizeof(std::complex); + for (size_t i=tstride.size()-1; i>0; --i) + tstride[i-1]=tstride[i]*ptrdiff_t(tshp[i]); + r2c(shape, stride_in, tstride, axes, true, data_in, tdata.data(), fct, nthreads); + cndarr> atmp(tdata.data(), tshp, tstride); + ndarr aout(data_out, shape, stride_out); + simple_iter iin(atmp); + rev_iter iout(aout, axes); + while(iin.remaining()>0) + { + auto v = atmp[iin.ofs()]; + aout[iout.ofs()] = v.r+v.i; + aout[iout.rev_ofs()] = v.r-v.i; + iin.advance(); iout.advance(); + } + } + +} // namespace detail + +using detail::FORWARD; +using detail::BACKWARD; +using detail::shape_t; +using detail::stride_t; +using detail::c2c; +using detail::c2r; +using detail::r2c; +using detail::r2r_fftpack; +using detail::r2r_separable_hartley; +using detail::r2r_genuine_hartley; +using detail::dct; +using detail::dst; + +} // namespace pocketfft + +#undef POCKETFFT_NOINLINE +#undef POCKETFFT_RESTRICT + +#endif // 
POCKETFFT_HDRONLY_H diff --git a/mlx/allocator.h b/mlx/allocator.h new file mode 100644 index 000000000..29421d47c --- /dev/null +++ b/mlx/allocator.h @@ -0,0 +1,64 @@ +#pragma once + +#include + +namespace mlx::core::allocator { + +// Simple wrapper around buffer pointers +// WARNING: Only Buffer objects constructed from and those that wrap +// raw pointers from mlx::allocator are supported. +class Buffer { + private: + void* ptr_; + + public: + Buffer(void* ptr) : ptr_(ptr){}; + + // Get the raw data pointer from the buffer + void* raw_ptr(); + + // Get the buffer pointer from the buffer + const void* ptr() const { + return ptr_; + }; + void* ptr() { + return ptr_; + }; +}; + +Buffer malloc(size_t size); + +void free(Buffer buffer); + +// Wait for running tasks to finish and free up memory +// if allocation fails +Buffer malloc_or_wait(size_t size); + +class Allocator { + /** Abstract base clase for a memory allocator. */ + public: + virtual Buffer malloc(size_t size) = 0; + virtual void free(Buffer buffer) = 0; + + Allocator() = default; + Allocator(const Allocator& other) = delete; + Allocator(Allocator&& other) = delete; + Allocator& operator=(const Allocator& other) = delete; + Allocator& operator=(Allocator&& other) = delete; + virtual ~Allocator() = default; +}; + +Allocator& allocator(); + +class CommonAllocator : public Allocator { + /** A general CPU allocator. */ + public: + virtual Buffer malloc(size_t size) override; + virtual void free(Buffer buffer) override; + + private: + CommonAllocator() = default; + friend Allocator& allocator(); +}; + +} // namespace mlx::core::allocator diff --git a/mlx/backend/accelerate/conv.cpp b/mlx/backend/accelerate/conv.cpp new file mode 100644 index 000000000..432e60c59 --- /dev/null +++ b/mlx/backend/accelerate/conv.cpp @@ -0,0 +1,18 @@ +#include + +#include +#include + +#include "mlx/backend/common/copy.h" +#include "mlx/primitives.h" +#include "mlx/utils.h" + +namespace mlx::core { + +void Convolution::eval_cpu(const std::vector& inputs, array& out) { + eval(inputs, out); + + // TODO: Add accelerate based optimizations for CPU conv +} + +} // namespace mlx::core diff --git a/mlx/backend/accelerate/utils.h b/mlx/backend/accelerate/utils.h new file mode 100644 index 000000000..7c2f1ae65 --- /dev/null +++ b/mlx/backend/accelerate/utils.h @@ -0,0 +1,26 @@ +#pragma once + +#include +#include "mlx/dtype.h" + +namespace mlx::core { + +BNNSDataType to_bnns_dtype(Dtype mlx_dtype) { + uint32_t size_bits = size_of(mlx_dtype) * 8; + switch (kindof(mlx_dtype)) { + case Dtype::Kind::b: + return BNNSDataTypeBoolean; + case Dtype::Kind::u: + return BNNSDataType(BNNSDataTypeUIntBit | size_bits); + case Dtype::Kind::i: + return BNNSDataType(BNNSDataTypeIntBit | size_bits); + case Dtype::Kind::f: + return BNNSDataType(BNNSDataTypeFloatBit | size_bits); + case Dtype::Kind::V: + return BNNSDataTypeBFloat16; + case Dtype::Kind::c: + throw std::invalid_argument("BNNS does not support complex types"); + } +} + +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/common/copy.h b/mlx/backend/common/copy.h new file mode 100644 index 000000000..dc46e91b4 --- /dev/null +++ b/mlx/backend/common/copy.h @@ -0,0 +1,27 @@ +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/common/utils.h" + +namespace mlx::core { + +enum class CopyType { + // Copy a raw scalar input into the full contiguous output + Scalar, + + // Copy the raw input buffer contiguously into a raw output buffer of the same + // size + Vector, + + // Copy the full virtual 
input to the full contiguous output + General, + + // Copy the full virtual input to the full virtual output. We assume the + // input and output have the same shape. + GeneralGeneral +}; + +void copy(const array& src, array& dst, CopyType ctype); +void copy_inplace(const array& src, array& dst, CopyType ctype); + +} // namespace mlx::core diff --git a/mlx/backend/common/sort.cpp b/mlx/backend/common/sort.cpp new file mode 100644 index 000000000..8f0295591 --- /dev/null +++ b/mlx/backend/common/sort.cpp @@ -0,0 +1,394 @@ +#include +#include +#include +#include + +#include "mlx/backend/common/copy.h" +#include "mlx/backend/common/utils.h" + +#include "mlx/primitives.h" + +namespace mlx::core { + +namespace { + +template +struct StridedIterator { + using iterator_category = std::random_access_iterator_tag; + using difference_type = IdxT; + using value_type = T; + using reference = value_type&; + using pointer = value_type*; + + // Constructors + StridedIterator() = default; + + explicit StridedIterator(T* ptr, size_t stride, difference_type offset = 0) + : ptr_(ptr + offset * stride), stride_(stride) {} + + explicit StridedIterator(array& arr, int axis, difference_type offset = 0) + : StridedIterator(arr.data(), arr.strides()[axis], offset) {} + + // Accessors + reference operator*() const { + return ptr_[0]; + } + + reference operator[](difference_type idx) const { + return ptr_[idx * stride_]; + } + + // Comparisons + bool operator==(const StridedIterator& other) const { + return ptr_ == other.ptr_ && stride_ == other.stride_; + } + + bool operator!=(const StridedIterator& other) const { + return ptr_ != other.ptr_; + } + + bool operator<(const StridedIterator& other) const { + return ptr_ < other.ptr_; + } + + bool operator>(const StridedIterator& other) const { + return ptr_ > other.ptr_; + } + + bool operator<=(const StridedIterator& other) const { + return ptr_ <= other.ptr_; + } + + bool operator>=(const StridedIterator& other) const { + return ptr_ >= other.ptr_; + } + + difference_type operator-(const StridedIterator& other) const { + return (ptr_ - other.ptr_) / stride_; + } + + // Moving + StridedIterator& operator++() { + ptr_ += stride_; + return *this; + } + + StridedIterator& operator--() { + ptr_ -= stride_; + return *this; + } + + StridedIterator& operator+=(difference_type diff) { + ptr_ += diff * stride_; + return *this; + } + + StridedIterator& operator-=(difference_type diff) { + ptr_ -= diff * stride_; + return *this; + } + + StridedIterator operator+(difference_type diff) { + return StridedIterator(ptr_, stride_, diff); + } + + StridedIterator operator-(difference_type diff) { + return StridedIterator(ptr_, stride_, -diff); + } + + private: + size_t stride_; + T* ptr_; +}; + +template +void sort(const array& in, array& out, int axis) { + // Copy input to output + CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General; + copy(in, out, ctype); + + // Get axis, shape and stride info + axis = axis < 0 ? 
axis + in.ndim() : axis; + size_t n_rows = in.size() / in.shape(axis); + + auto remaining_shape = in.shape(); + remaining_shape.erase(remaining_shape.begin() + axis); + + auto remaining_strides = in.strides(); + remaining_strides.erase(remaining_strides.begin() + axis); + + size_t axis_stride = in.strides()[axis]; + int axis_size = in.shape(axis); + + // Perform sorting in place + for (int i = 0; i < n_rows; i++) { + size_t loc = elem_to_loc(i, remaining_shape, remaining_strides); + T* data_ptr = out.data() + loc; + + StridedIterator st(data_ptr, axis_stride, 0); + StridedIterator ed(data_ptr, axis_stride, axis_size); + + std::stable_sort(st, ed); + } +} + +template +void argsort(const array& in, array& out, int axis) { + // Allocate output + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + // Get axis, shape and stride info + axis = axis < 0 ? axis + in.ndim() : axis; + size_t n_rows = in.size() / in.shape(axis); + + auto remaining_shape = in.shape(); + remaining_shape.erase(remaining_shape.begin() + axis); + + auto remaining_strides = in.strides(); + remaining_strides.erase(remaining_strides.begin() + axis); + + size_t axis_stride = in.strides()[axis]; + int axis_size = in.shape(axis); + + // Perform sorting + for (int i = 0; i < n_rows; i++) { + size_t loc = elem_to_loc(i, remaining_shape, remaining_strides); + const T* data_ptr = in.data() + loc; + IdxT* idx_ptr = out.data() + loc; + + StridedIterator st_(idx_ptr, axis_stride, 0); + StridedIterator ed_(idx_ptr, axis_stride, axis_size); + + // Initialize with iota + std::iota(st_, ed_, IdxT(0)); + + // Sort according to vals + StridedIterator st(idx_ptr, axis_stride, 0); + StridedIterator ed(idx_ptr, axis_stride, axis_size); + + std::stable_sort(st, ed, [data_ptr, axis_stride](IdxT a, IdxT b) { + auto v1 = data_ptr[a * axis_stride]; + auto v2 = data_ptr[b * axis_stride]; + return v1 < v2 || (v1 == v2 && a < b); + }); + } +} + +template +void partition(const array& in, array& out, int axis, int kth) { + // Copy input to output + CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General; + copy(in, out, ctype); + + // Get axis, shape and stride info + axis = axis < 0 ? axis + in.ndim() : axis; + size_t n_rows = in.size() / in.shape(axis); + + auto remaining_shape = in.shape(); + remaining_shape.erase(remaining_shape.begin() + axis); + + auto remaining_strides = in.strides(); + remaining_strides.erase(remaining_strides.begin() + axis); + + size_t axis_stride = in.strides()[axis]; + int axis_size = in.shape(axis); + + kth = kth < 0 ? kth + axis_size : kth; + + // Perform partition in place + for (int i = 0; i < n_rows; i++) { + size_t loc = elem_to_loc(i, remaining_shape, remaining_strides); + T* data_ptr = out.data() + loc; + + StridedIterator st(data_ptr, axis_stride, 0); + StridedIterator md(data_ptr, axis_stride, kth); + StridedIterator ed(data_ptr, axis_stride, axis_size); + + std::nth_element(st, md, ed); + } +} + +template +void argpartition(const array& in, array& out, int axis, int kth) { + // Allocate output + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + // Get axis, shape and stride info + axis = axis < 0 ? axis + in.ndim() : axis; + size_t n_rows = in.size() / in.shape(axis); + + auto remaining_shape = in.shape(); + remaining_shape.erase(remaining_shape.begin() + axis); + + auto remaining_strides = in.strides(); + remaining_strides.erase(remaining_strides.begin() + axis); + + size_t axis_stride = in.strides()[axis]; + int axis_size = in.shape(axis); + + kth = kth < 0 ? 
kth + axis_size : kth; + + // Perform partition + for (int i = 0; i < n_rows; i++) { + size_t loc = elem_to_loc(i, remaining_shape, remaining_strides); + const T* data_ptr = in.data() + loc; + IdxT* idx_ptr = out.data() + loc; + + StridedIterator st_(idx_ptr, axis_stride, 0); + StridedIterator ed_(idx_ptr, axis_stride, axis_size); + + // Initialize with iota + std::iota(st_, ed_, IdxT(0)); + + // Sort according to vals + StridedIterator st(idx_ptr, axis_stride, 0); + StridedIterator md(idx_ptr, axis_stride, kth); + StridedIterator ed(idx_ptr, axis_stride, axis_size); + + std::nth_element(st, md, ed, [data_ptr, axis_stride](IdxT a, IdxT b) { + auto v1 = data_ptr[a * axis_stride]; + auto v2 = data_ptr[b * axis_stride]; + return v1 < v2 || (v1 == v2 && a < b); + }); + } +} + +} // namespace + +void ArgSort::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + + switch (in.dtype()) { + case bool_: + return argsort(in, out, axis_); + case uint8: + return argsort(in, out, axis_); + case uint16: + return argsort(in, out, axis_); + case uint32: + return argsort(in, out, axis_); + case uint64: + return argsort(in, out, axis_); + case int8: + return argsort(in, out, axis_); + case int16: + return argsort(in, out, axis_); + case int32: + return argsort(in, out, axis_); + case int64: + return argsort(in, out, axis_); + case float32: + return argsort(in, out, axis_); + case float16: + return argsort(in, out, axis_); + case bfloat16: + return argsort(in, out, axis_); + case complex64: + return argsort(in, out, axis_); + } +} + +void Sort::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + + switch (in.dtype()) { + case bool_: + return sort(in, out, axis_); + case uint8: + return sort(in, out, axis_); + case uint16: + return sort(in, out, axis_); + case uint32: + return sort(in, out, axis_); + case uint64: + return sort(in, out, axis_); + case int8: + return sort(in, out, axis_); + case int16: + return sort(in, out, axis_); + case int32: + return sort(in, out, axis_); + case int64: + return sort(in, out, axis_); + case float32: + return sort(in, out, axis_); + case float16: + return sort(in, out, axis_); + case bfloat16: + return sort(in, out, axis_); + case complex64: + return sort(in, out, axis_); + } +} + +void ArgPartition::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + + switch (in.dtype()) { + case bool_: + return argpartition(in, out, axis_, kth_); + case uint8: + return argpartition(in, out, axis_, kth_); + case uint16: + return argpartition(in, out, axis_, kth_); + case uint32: + return argpartition(in, out, axis_, kth_); + case uint64: + return argpartition(in, out, axis_, kth_); + case int8: + return argpartition(in, out, axis_, kth_); + case int16: + return argpartition(in, out, axis_, kth_); + case int32: + return argpartition(in, out, axis_, kth_); + case int64: + return argpartition(in, out, axis_, kth_); + case float32: + return argpartition(in, out, axis_, kth_); + case float16: + return argpartition(in, out, axis_, kth_); + case bfloat16: + return argpartition(in, out, axis_, kth_); + case complex64: + return argpartition(in, out, axis_, kth_); + } +} + +void Partition::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + + switch (in.dtype()) { + case bool_: + return partition(in, out, axis_, kth_); + case uint8: + return partition(in, out, axis_, kth_); + case uint16: + return 
partition(in, out, axis_, kth_); + case uint32: + return partition(in, out, axis_, kth_); + case uint64: + return partition(in, out, axis_, kth_); + case int8: + return partition(in, out, axis_, kth_); + case int16: + return partition(in, out, axis_, kth_); + case int32: + return partition(in, out, axis_, kth_); + case int64: + return partition(in, out, axis_, kth_); + case float32: + return partition(in, out, axis_, kth_); + case float16: + return partition(in, out, axis_, kth_); + case bfloat16: + return partition(in, out, axis_, kth_); + case complex64: + return partition(in, out, axis_, kth_); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp new file mode 100644 index 000000000..6a133329c --- /dev/null +++ b/mlx/backend/metal/device.cpp @@ -0,0 +1,257 @@ +#include +#include +#include +#include + +#define NS_PRIVATE_IMPLEMENTATION +#define CA_PRIVATE_IMPLEMENTATION +#define MTL_PRIVATE_IMPLEMENTATION + +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/metal.h" +#include "mlx/backend/metal/mps/gemm.h" + +namespace fs = std::filesystem; + +namespace mlx::core::metal { + +static Device metal_device_; + +namespace { + +// TODO nicer way to set this or possibly expose as an environment variable +static constexpr int MAX_BUFFERS_PER_QUEUE = 12; + +static constexpr const char* default_mtllib_path = METAL_PATH; + +auto load_device() { + MTL::Device* device = MTL::CreateSystemDefaultDevice(); + if (!device) { + throw std::runtime_error("Failed to load device"); + } + return device; +} + +std::pair load_library_from_path( + MTL::Device* device, + const char* path) { + auto library = NS::String::string(path, NS::UTF8StringEncoding); + NS::Error* error; + auto lib = device->newLibrary(library, &error); + + return std::make_pair(lib, error); +} + +MTL::Library* load_library( + MTL::Device* device, + const std::string& lib_name = "mlx", + const char* lib_path = default_mtllib_path) { + // Firstly, search for the metallib in the same path as this binary + std::string first_path = get_colocated_mtllib_path(lib_name); + if (first_path.size() != 0) { + auto [lib, error] = load_library_from_path(device, first_path.c_str()); + if (lib) { + return lib; + } + } + + // Couldn't find it so let's load it from default_mtllib_path + { + auto [lib, error] = load_library_from_path(device, lib_path); + if (!lib) { + std::ostringstream msg; + msg << error->localizedDescription()->utf8String() << "\n" + << "Failed to load device library from <" << lib_path << ">" + << " or <" << first_path << ">."; + throw std::runtime_error(msg.str()); + } + return lib; + } +} + +} // namespace + +Device::Device() + : pool_(NS::AutoreleasePool::alloc()->init()), + device_(load_device()), + library_map_({{"mlx", load_library(device_)}}) {} + +Device::~Device() { + for (auto& q : queue_map_) { + q.second->release(); + } + for (auto& k : kernel_map_) { + k.second->release(); + } + for (auto& l : library_map_) { + l.second->release(); + } + device_->release(); + pool_->release(); +} + +void Device::new_queue(int index) { + // Multiple threads can ask the device for queues + // We lock this as a critical section for safety + const std::lock_guard lock(mtx_); + auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE); + if (!q) { + throw std::runtime_error( + "[metal::Device] Failed to make new command queue."); + } + queue_map_.insert({index, q}); +} + +int Device::get_command_buffer_ops(int index) { + auto bit = buffer_map_.find(index); + return 
bit->second.first; +} + +void Device::increment_command_buffer_ops(int index) { + auto bit = buffer_map_.find(index); + bit->second.first++; +} + +MTL::CommandBuffer* Device::get_command_buffer(int index) { + auto bit = buffer_map_.find(index); + return (bit == buffer_map_.end()) ? nullptr : bit->second.second; +} + +MTL::CommandBuffer* Device::new_command_buffer(int index) { + auto qit = queue_map_.find(index); + if (qit == queue_map_.end()) { + throw std::runtime_error( + "[metal::Device] Attempting to get command buffer for invalid queue."); + } + + auto cb = qit->second->commandBufferWithUnretainedReferences(); + + if (!cb) { + throw std::runtime_error( + "[metal::Device] Unable to create new command buffer"); + } + + // Increment ref count so the buffer is not garbage collected + cb->retain(); + + return buffer_map_.insert({index, {0, cb}}).first->second.second; +} + +void Device::commit_command_buffer(int index) { + auto bit = buffer_map_.find(index); + bit->second.second->commit(); + bit->second.second->release(); + buffer_map_.erase(bit); +} + +void Device::end_encoding(int index) { + auto eit = encoder_map_.find(index); + if (eit != encoder_map_.end()) { + eit->second->endEncoding(); + eit->second->release(); + encoder_map_.erase(eit); + } +} + +MTL::ComputeCommandEncoder* Device::get_command_encoder(int index) { + auto eit = encoder_map_.find(index); + if (eit == encoder_map_.end()) { + auto cb = get_command_buffer(index); + auto compute_encoder = cb->computeCommandEncoder(); + // Increment ref count so the buffer is not garbage collected + compute_encoder->retain(); + eit = encoder_map_.insert({index, compute_encoder}).first; + } + return eit->second; +} + +MTL::ArgumentEncoder* Device::argument_encoder( + const std::vector& arg_descs) const { + // NB array here is already autoreleased but the returned argument + // encoder is owned by the caller and must be released/autoreleased + NS::Array* arg_desc_arr = NS::Array::array( + reinterpret_cast(arg_descs.data()), arg_descs.size()); + return device_->newArgumentEncoder(arg_desc_arr); +} + +void Device::register_library( + const std::string& lib_name, + const std::string& lib_path) { + if (auto it = library_map_.find(lib_name); it == library_map_.end()) { + auto new_lib = load_library(device_, lib_name, lib_path.c_str()); + library_map_.insert({lib_name, new_lib}); + } +} + +void Device::register_library( + const std::string& lib_name, + const std::function& lib_path_func) { + if (auto it = library_map_.find(lib_name); it == library_map_.end()) { + std::string new_lib_path = lib_path_func(lib_name); + auto new_lib = load_library(device_, lib_name, new_lib_path.c_str()); + library_map_.insert({lib_name, new_lib}); + } +} + +MTL::ComputePipelineState* Device::get_kernel( + const std::string& name, + const std::string& lib_name /* = "mlx" */) { + // Look for cached kernel + if (auto it = kernel_map_.find(name); it != kernel_map_.end()) { + return it->second; + } + + // Prepare new kernel + + // Search for cached metal lib + MTL::Library* mtl_lib; + if (auto it = library_map_.find(name); it != library_map_.end()) { + mtl_lib = it->second; + } else { // Look for metallib alongside library + register_library(lib_name); + mtl_lib = library_map_[lib_name]; + } + + // Pull kernel from library + auto ns_name = NS::String::string(name.c_str(), NS::ASCIIStringEncoding); + auto mtl_function = mtl_lib->newFunction(ns_name); + + // Compile kernel to compute pipeline + NS::Error* error = nullptr; + MTL::ComputePipelineState* kernel; + if 
(mtl_function) { + kernel = device_->newComputePipelineState(mtl_function, &error); + mtl_function->release(); + } + if (!mtl_function || !kernel) { + std::ostringstream msg; + msg << "[metal::Device] Unable to load kernel " << name << "\n"; + if (error) { + msg << error->localizedDescription()->utf8String() << "\n"; + } + throw std::runtime_error(msg.str()); + } + + // Add kernel to cache + kernel_map_.insert({name, kernel}); + return kernel; +} + +Device& device(mlx::core::Device) { + return metal_device_; +} + +NS::AutoreleasePool*& thread_autorelease_pool() { + static thread_local NS::AutoreleasePool* p = + NS::AutoreleasePool::alloc()->init(); + return p; +} + +void new_stream(Stream stream) { + thread_autorelease_pool(); + if (stream.device == mlx::core::Device::gpu) { + device(stream.device).new_queue(stream.index); + } +} + +} // namespace mlx::core::metal diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp new file mode 100644 index 000000000..0cf34e5d8 --- /dev/null +++ b/mlx/backend/metal/fft.cpp @@ -0,0 +1,10 @@ +#include "mlx/primitives.h" + +namespace mlx::core { + +void FFT::eval_gpu(const std::vector& inputs, array& out) { + auto& in = inputs[0]; + throw std::runtime_error("[FFT] NYI for Metal backend."); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt new file mode 100644 index 000000000..8fc5eac30 --- /dev/null +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -0,0 +1,83 @@ +set( + HEADERS + ${CMAKE_CURRENT_SOURCE_DIR}/bf16.h + ${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h + ${CMAKE_CURRENT_SOURCE_DIR}/complex.h + ${CMAKE_CURRENT_SOURCE_DIR}/defines.h + ${CMAKE_CURRENT_SOURCE_DIR}/erf.h + ${CMAKE_CURRENT_SOURCE_DIR}/reduce.h + ${CMAKE_CURRENT_SOURCE_DIR}/utils.h +) + +set( + KERNELS + "arange" + "arg_reduce" + "binary" + "conv" + "copy" + "gemm" + "gemv" + "random" + "reduce" + "scan" + "softmax" + "sort" + "unary" + "indexing" +) + +function(build_kernel KERNEL) + set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal) + set(HEADERS_PADDED ${HEADERS}) + if(${KERNEL} STREQUAL "gemm") + set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/gemm/gemm.h) + endif() + if(${KERNEL} STREQUAL "conv") + set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/gemm/conv.h) + endif() + add_custom_command( + COMMAND xcrun -sdk macosx metal -Wall -Wextra + -fno-fast-math + -c ${SRCFILE} + -I${PROJECT_SOURCE_DIR} + -o ${KERNEL}.air + DEPENDS ${SRCFILE} ${HEADERS_PADDED} + OUTPUT ${KERNEL}.air + COMMENT "Building ${KERNEL}.air" + VERBATIM + ) +endfunction(build_kernel) + +foreach(KERNEL ${KERNELS}) + build_kernel(${KERNEL}) + set(KERNEL_AIR ${KERNEL}.air ${KERNEL_AIR}) +endforeach() + +add_custom_command( + OUTPUT ${MLX_METAL_PATH}/mlx.metallib + COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib + DEPENDS ${KERNEL_AIR} + COMMENT "Building mlx.metallib" + VERBATIM +) + +add_custom_target( + mlx-metallib + DEPENDS + ${MLX_METAL_PATH}/mlx.metallib +) + +add_dependencies( + mlx + mlx-metallib +) + +# Install metallib +include(GNUInstallDirs) + +install( + FILES ${MLX_METAL_PATH}/mlx.metallib + DESTINATION ${CMAKE_INSTALL_LIBDIR} + COMPONENT metallib +) diff --git a/mlx/backend/metal/kernels/conv_params.h b/mlx/backend/metal/kernels/conv_params.h new file mode 100644 index 000000000..b3748ef46 --- /dev/null +++ b/mlx/backend/metal/kernels/conv_params.h @@ -0,0 +1,17 @@ +#pragma once + +template +struct MLXConvParams { + const int N; // Batch size + 
const int C; // In channels + const int O; // Out channels + const int iS[NDIM]; // Input spatial dim + const int wS[NDIM]; // Weight spatial dim + const int oS[NDIM]; // Output spatial dim + const int str[NDIM]; // Kernel strides + const int pad[NDIM]; // Input padding + const int dil[NDIM]; // Kernel dilation + const size_t in_strides[NDIM + 2]; // In strides + const size_t wt_strides[NDIM + 2]; // Wt strides + const size_t out_strides[NDIM + 2]; // Out strides +}; diff --git a/mlx/backend/metal/kernels/indexing.metal b/mlx/backend/metal/kernels/indexing.metal new file mode 100644 index 000000000..fc0d9ee72 --- /dev/null +++ b/mlx/backend/metal/kernels/indexing.metal @@ -0,0 +1,253 @@ +#include +#include + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/reduce.h" +#include "mlx/backend/metal/kernels/utils.h" + +using namespace metal; + +///////////////////////////////////////////////////////////////////// +// Gather kernel +///////////////////////////////////////////////////////////////////// + +template +struct Indices { + const array buffers [[id(0)]]; + device int* shapes [[id(NIDX + 1)]]; + device size_t* strides [[id(NIDX + 2)]]; + const int ndim [[id(NIDX + 3)]]; +}; + +template +inline size_t offset_neg_idx(IdxT idx, size_t size) { + return (idx < 0) ? idx + size : idx; +} + +template <> +inline size_t offset_neg_idx(bool idx, size_t) { + return idx; +} + +template <> +inline size_t offset_neg_idx(uint32_t idx, size_t) { + return idx; +} + +template +[[kernel]] void gather( + const device T *src [[buffer(0)]], + const device Indices& indices [[buffer(1)]], + device T *out [[buffer(2)]], + const device int *src_shape [[buffer(3)]], + const device size_t *src_strides [[buffer(4)]], + const device size_t& src_ndim [[buffer(5)]], + const device int *slice_sizes [[buffer(6)]], + const device size_t& slice_size [[buffer(7)]], + const device int *axes [[buffer(8)]], + uint gid [[thread_position_in_grid]]) { + + auto ind_idx = gid / slice_size; + auto ind_offset = gid % slice_size; + + size_t src_idx = 0; + for (int i = 0; i < NIDX; ++i) { + auto idx_loc = elem_to_loc( + ind_idx, + &indices.shapes[indices.ndim * i], + &indices.strides[indices.ndim * i], + indices.ndim); + auto ax = axes[i]; + auto idx_val = offset_neg_idx( + indices.buffers[i][idx_loc], src_shape[ax]); + src_idx += idx_val * src_strides[ax]; + } + + auto src_offset = elem_to_loc( + ind_offset, slice_sizes, src_strides, src_ndim); + out[gid] = src[src_idx + src_offset]; +} + +#define instantiate_gather4(name, src_type, ind_type, nindex) \ +template [[host_name("gather" name "_" #nindex)]] \ +[[kernel]] void gather( \ + const device src_type *src [[buffer(0)]], \ + const device Indices& indices [[buffer(1)]], \ + device src_type *out [[buffer(2)]], \ + const device int *src_shape [[buffer(3)]], \ + const device size_t *src_strides [[buffer(4)]], \ + const device size_t& src_ndim [[buffer(5)]], \ + const device int *slice_sizes [[buffer(6)]], \ + const device size_t& slice_size [[buffer(7)]], \ + const device int* axes [[buffer(8)]], \ + uint gid [[thread_position_in_grid]]); + +// Special for case NIDX=0 +instantiate_gather4("bool_", bool, bool, 0) +instantiate_gather4("uint8", uint8_t, bool, 0) +instantiate_gather4("uint16", uint16_t, bool, 0) +instantiate_gather4("uint32", uint32_t, bool, 0) +instantiate_gather4("uint64", uint64_t, bool, 0) +instantiate_gather4("int8", int8_t, bool, 0) +instantiate_gather4("int16", int16_t, bool, 0) +instantiate_gather4("int32", int32_t, bool, 0) 
+instantiate_gather4("int64", int64_t, bool, 0) +instantiate_gather4("float16", half, bool, 0) +instantiate_gather4("float32", float, bool, 0) +instantiate_gather4("bfloat16", bfloat16_t, bool, 0) + +#define instantiate_gather3(name, src_type, ind_type) \ + instantiate_gather4(name, src_type, ind_type, 1) \ + instantiate_gather4(name, src_type, ind_type, 2) \ + instantiate_gather4(name, src_type, ind_type, 3) \ + instantiate_gather4(name, src_type, ind_type, 4) \ + instantiate_gather4(name, src_type, ind_type, 5) \ + instantiate_gather4(name, src_type, ind_type, 6) \ + instantiate_gather4(name, src_type, ind_type, 7) \ + instantiate_gather4(name, src_type, ind_type, 8) \ + instantiate_gather4(name, src_type, ind_type, 9) \ + instantiate_gather4(name, src_type, ind_type, 10) + +#define instantiate_gather(name, src_type) \ + instantiate_gather3(#name "bool_", src_type, bool) \ + instantiate_gather3(#name "uint8", src_type, uint8_t) \ + instantiate_gather3(#name "uint16", src_type, uint16_t) \ + instantiate_gather3(#name "uint32", src_type, uint32_t) \ + instantiate_gather3(#name "uint64", src_type, uint64_t) \ + instantiate_gather3(#name "int8", src_type, int8_t) \ + instantiate_gather3(#name "int16", src_type, int16_t) \ + instantiate_gather3(#name "int32", src_type, int32_t) \ + instantiate_gather3(#name "int64", src_type, int64_t) + +instantiate_gather(bool_, bool) +instantiate_gather(uint8, uint8_t) +instantiate_gather(uint16, uint16_t) +instantiate_gather(uint32, uint32_t) +instantiate_gather(uint64, uint64_t) +instantiate_gather(int8, int8_t) +instantiate_gather(int16, int16_t) +instantiate_gather(int32, int32_t) +instantiate_gather(int64, int64_t) +instantiate_gather(float16, half) +instantiate_gather(float32, float) +instantiate_gather(bfloat16, bfloat16_t) + +///////////////////////////////////////////////////////////////////// +// Scatter kernel +///////////////////////////////////////////////////////////////////// + +template +[[kernel]] void scatter( + const device Indices& indices [[buffer(0)]], + const device T *updates [[buffer(1)]], + device mlx_atomic *out [[buffer(2)]], + const device int *upd_shape [[buffer(3)]], + const device size_t *upd_strides [[buffer(4)]], + const device size_t& upd_ndim [[buffer(5)]], + const device size_t& upd_size [[buffer(6)]], + const device int *out_shape [[buffer(7)]], + const device size_t *out_strides [[buffer(8)]], + const device size_t& out_ndim [[buffer(9)]], + const device int* axes [[buffer(10)]], + uint gid [[thread_position_in_grid]]) { + + Op op; + auto ind_idx = gid / upd_size; + auto ind_offset = gid % upd_size; + + size_t out_idx = 0; + for (int i = 0; i < NIDX; ++i) { + auto idx_loc = elem_to_loc( + ind_idx, + &indices.shapes[indices.ndim * i], + &indices.strides[indices.ndim * i], + indices.ndim); + auto ax = axes[i]; + auto idx_val = offset_neg_idx( + indices.buffers[i][idx_loc], out_shape[ax]); + out_idx += idx_val * out_strides[ax]; + } + + auto out_offset = elem_to_loc( + ind_offset, upd_shape + indices.ndim, out_strides, out_ndim); + auto upd_idx = elem_to_loc(gid, upd_shape, upd_strides, upd_ndim); + + op.atomic_update(out + out_idx + out_offset, updates[upd_idx]); +} + +#define instantiate_scatter4(name, type, ind_type, op_type, nindex) \ +template [[host_name("scatter" name "_" #nindex)]] \ +[[kernel]] void scatter( \ + const device Indices& indices [[buffer(0)]], \ + const device type *updates [[buffer(1)]], \ + device mlx_atomic *out [[buffer(2)]], \ + const device int *upd_shape [[buffer(3)]], \ + const device size_t 
*upd_strides [[buffer(4)]], \ + const device size_t& upd_ndim [[buffer(5)]], \ + const device size_t& upd_size [[buffer(6)]], \ + const device int *out_shape [[buffer(7)]], \ + const device size_t *out_strides [[buffer(8)]], \ + const device size_t& out_ndim [[buffer(9)]], \ + const device int* axes [[buffer(10)]], \ + uint gid [[thread_position_in_grid]]); + +// Special case NINDEX=0 +#define instantiate_scatter_nd0(name, type) \ + instantiate_scatter4(#name "none", type, bool, None, 0) \ + instantiate_scatter4(#name "_sum", type, bool, Sum, 0) \ + instantiate_scatter4(#name "_prod", type, bool, Prod, 0) \ + instantiate_scatter4(#name "_max", type, bool, Max, 0) \ + instantiate_scatter4(#name "_min", type, bool, Min, 0) + +#define instantiate_scatter3(name, type, ind_type, op_type) \ + instantiate_scatter4(name, type, ind_type, op_type, 1) \ + instantiate_scatter4(name, type, ind_type, op_type, 2) \ + instantiate_scatter4(name, type, ind_type, op_type, 3) \ + instantiate_scatter4(name, type, ind_type, op_type, 4) \ + instantiate_scatter4(name, type, ind_type, op_type, 5) \ + instantiate_scatter4(name, type, ind_type, op_type, 6) \ + instantiate_scatter4(name, type, ind_type, op_type, 7) \ + instantiate_scatter4(name, type, ind_type, op_type, 8) \ + instantiate_scatter4(name, type, ind_type, op_type, 9) \ + instantiate_scatter4(name, type, ind_type, op_type, 10) + +#define instantiate_scatter2(name, type, ind_type) \ + instantiate_scatter3(name "_none", type, ind_type, None) \ + instantiate_scatter3(name "_sum", type, ind_type, Sum) \ + instantiate_scatter3(name "_prod", type, ind_type, Prod) \ + instantiate_scatter3(name "_max", type, ind_type, Max) \ + instantiate_scatter3(name "_min", type, ind_type, Min) + +#define instantiate_scatter(name, type) \ + instantiate_scatter2(#name "bool_", type, bool) \ + instantiate_scatter2(#name "uint8", type, uint8_t) \ + instantiate_scatter2(#name "uint16", type, uint16_t) \ + instantiate_scatter2(#name "uint32", type, uint32_t) \ + instantiate_scatter2(#name "uint64", type, uint64_t) \ + instantiate_scatter2(#name "int8", type, int8_t) \ + instantiate_scatter2(#name "int16", type, int16_t) \ + instantiate_scatter2(#name "int32", type, int32_t) \ + instantiate_scatter2(#name "int64", type, int64_t) + +// TODO uint64 and int64 unsupported +instantiate_scatter_nd0(bool_, bool) +instantiate_scatter_nd0(uint8, uint8_t) +instantiate_scatter_nd0(uint16, uint16_t) +instantiate_scatter_nd0(uint32, uint32_t) +instantiate_scatter_nd0(int8, int8_t) +instantiate_scatter_nd0(int16, int16_t) +instantiate_scatter_nd0(int32, int32_t) +instantiate_scatter_nd0(float16, half) +instantiate_scatter_nd0(float32, float) +instantiate_scatter_nd0(bfloat16, bfloat16_t) + +instantiate_scatter(bool_, bool) +instantiate_scatter(uint8, uint8_t) +instantiate_scatter(uint16, uint16_t) +instantiate_scatter(uint32, uint32_t) +instantiate_scatter(int8, int8_t) +instantiate_scatter(int16, int16_t) +instantiate_scatter(int32, int32_t) +instantiate_scatter(float16, half) +instantiate_scatter(float32, float) +instantiate_scatter(bfloat16, bfloat16_t) diff --git a/mlx/backend/metal/kernels/reduce.h b/mlx/backend/metal/kernels/reduce.h new file mode 100644 index 000000000..d670f245f --- /dev/null +++ b/mlx/backend/metal/kernels/reduce.h @@ -0,0 +1,174 @@ +#pragma once + +#include +#include + +#include "mlx/backend/metal/kernels/atomic.h" +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/utils.h" + +union bool4_or_uint { + bool4 b; + unsigned int i; +}; + 
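// bool4_or_uint lets the boolean And/Or reductions below update a single
// bool lane of the output with one 32-bit atomic fetch_and / fetch_or,
// since the mlx_atomic helpers operate on whole 32-bit words.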
+struct None { + template + void atomic_update(device mlx_atomic* out, T val, int offset = 0) { + mlx_atomic_store_explicit(out, val, offset); + } +}; + +struct And { + bool simd_reduce(bool val) { + return simd_all(val); + }; + + static constexpr constant bool init = true; + + void atomic_update( + device mlx_atomic* out, + bool val, + int elem_idx, + int offset = 0) { + if (!val) { + bool4_or_uint update; + update.b = {true, true, true, true}; + update.b[elem_idx] = false; + mlx_atomic_fetch_and_explicit(out, update.i, offset); + } + } + + void atomic_update(device mlx_atomic* out, bool val, int offset = 0) { + if (!val) { + mlx_atomic_store_explicit(out, val, offset); + } + } + + // Non atomic update + void update(device bool* out, bool val) { + *out &= val; + } + + // Operator + bool operator()(bool a, bool b) { + return a && b; + } +}; + +struct Or { + bool simd_reduce(bool val) { + return simd_any(val); + }; + + static constexpr constant bool init = false; + + void atomic_update( + device mlx_atomic* out, + bool val, + int elem_idx, + int offset = 0) { + if (val) { + bool4_or_uint update; + update.b = {false, false, false, false}; + update.b[elem_idx] = true; + mlx_atomic_fetch_or_explicit(out, update.i, offset); + } + } + + void atomic_update(device mlx_atomic* out, bool val, int offset = 0) { + if (val) { + mlx_atomic_store_explicit(out, val, offset); + } + } + + // Non atomic update + void update(device bool* out, bool val) { + *out |= val; + } + + // Operator + bool operator()(bool a, bool b) { + return a || b; + } +}; + +template +struct Sum { + template + T simd_reduce(T val) { + return simd_sum(val); + }; + + static constexpr constant U init = U(0); + + template + void atomic_update(device mlx_atomic* out, T val, int offset = 0) { + mlx_atomic_fetch_add_explicit(out, val, offset); + } + + // Operator + U operator()(U a, U b) { + return a + b; + } +}; + +template +struct Prod { + template + T simd_reduce(T val) { + return simd_product(val); + }; + + static constexpr constant U init = U(1); + + template + void atomic_update(device mlx_atomic* out, T val, int offset = 0) { + mlx_atomic_fetch_mul_explicit(out, val, offset); + } + + // Operator + U operator()(U a, U b) { + return a * b; + } +}; + +template +struct Min { + template + T simd_reduce(T val) { + return simd_min(val); + }; + + static constexpr constant U init = Limits::max; + + template + void atomic_update(device mlx_atomic* out, T val, int offset = 0) { + mlx_atomic_fetch_min_explicit(out, val, offset); + } + + // Operator + U operator()(U a, U b) { + return a < b ? a : b; + } +}; + +template +struct Max { + template + T simd_reduce(T val) { + return simd_max(val); + }; + + static constexpr constant U init = Limits::min; + + template + void atomic_update(device mlx_atomic* out, T val, int offset = 0) { + mlx_atomic_fetch_max_explicit(out, val, offset); + } + + // Operator + U operator()(U a, U b) { + return a > b ? 
a : b; + } +}; diff --git a/mlx/backend/metal/kernels/scan.metal b/mlx/backend/metal/kernels/scan.metal new file mode 100644 index 000000000..ba2368e25 --- /dev/null +++ b/mlx/backend/metal/kernels/scan.metal @@ -0,0 +1,492 @@ +#include +#include + +#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/kernels/utils.h" + +using namespace metal; + +template +struct CumSum { + static constexpr constant U init = static_cast(0); + + template + U operator()(U a, T b) { + return a + b; + } + + U simd_scan(U x) { + return simd_prefix_inclusive_sum(x); + } + + U simd_exclusive_scan(U x) { + return simd_prefix_exclusive_sum(x); + } +}; + +template +struct CumProd { + static constexpr constant U init = static_cast(1.0f); + + template + U operator()(U a, T b) { + return a * b; + } + + U simd_scan(U x) { + return simd_prefix_inclusive_product(x); + } + + U simd_exclusive_scan(U x) { + return simd_prefix_exclusive_product(x); + } +}; + +template <> +struct CumProd { + static constexpr constant bool init = true; + + template + bool operator()(bool a, T b) { + return a & static_cast(b); + } + + bool simd_scan(bool x) { + for (int i=1; i<=16; i*=2) { + bool other = simd_shuffle_up(x, i); + x &= other; + } + return x; + } + + bool simd_exclusive_scan(bool x) { + x = simd_scan(x); + return simd_shuffle_and_fill_up(x, init, 1); + } +}; + +template +struct CumMax { + static constexpr constant U init = Limits::min; + + template + U operator()(U a, T b) { + return (a >= b) ? a : b; + } + + U simd_scan(U x) { + for (int i=1; i<=16; i*=2) { + U other = simd_shuffle_up(x, i); + x = (x >= other) ? x : other; + } + return x; + } + + U simd_exclusive_scan(U x) { + x = simd_scan(x); + return simd_shuffle_and_fill_up(x, init, 1); + } +}; + +template +struct CumMin { + static constexpr constant U init = Limits::max; + + template + U operator()(U a, T b) { + return (a <= b) ? a : b; + } + + U simd_scan(U x) { + for (int i=1; i<=16; i*=2) { + U other = simd_shuffle_up(x, i); + x = (x <= other) ? 
x : other; + } + return x; + } + + U simd_exclusive_scan(U x) { + x = simd_scan(x); + return simd_shuffle_and_fill_up(x, init, 1); + } +}; + +template +inline void load_unsafe(U values[N_READS], const device T * input) { + if (reverse) { + for (int i=0; i +inline void load_safe(U values[N_READS], const device T * input, int start, int total, U init) { + if (reverse) { + for (int i=0; i +inline void write_unsafe(U values[N_READS], device U * out) { + if (reverse) { + for (int i=0; i +inline void write_safe(U values[N_READS], device U * out, int start, int total) { + if (reverse) { + for (int i=0; i +[[kernel]] void contiguous_scan( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t & axis_size [[buffer(2)]], + uint gid [[thread_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint lsize [[threads_per_threadgroup]], + uint simd_size [[threads_per_simdgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + Op op; + + // Position the pointers + in += (gid / lsize) * axis_size; + out += (gid / lsize) * axis_size; + + // Compute the number of simd_groups + uint simd_groups = lsize / simd_size; + + // Allocate memory + U prefix = Op::init; + U values[N_READS]; + threadgroup U simdgroup_sums[32]; + + // Loop over the reduced axis in blocks of size ceildiv(axis_size, N_READS*lsize) + // Read block + // Compute inclusive scan of the block + // Compute inclusive scan per thread + // Compute exclusive scan of thread sums in simdgroup + // Write simdgroup sums in SM + // Compute exclusive scan of simdgroup sums + // Compute the output by scanning prefix, prev_simdgroup, prev_thread, value + // Write block + + for (uint r = 0; r < ceildiv(axis_size, N_READS*lsize); r++) { + // Compute the block offset + uint offset = r*lsize*N_READS + lid*N_READS; + + // Read the values + if (reverse) { + if ((offset + N_READS) < axis_size) { + load_unsafe(values, in + axis_size - offset - N_READS); + } else { + load_safe(values, in + axis_size - offset - N_READS, offset, axis_size, Op::init); + } + } else { + if ((offset + N_READS) < axis_size) { + load_unsafe(values, in + offset); + } else { + load_safe(values, in + offset, offset, axis_size, Op::init); + } + } + + // Compute an inclusive scan per thread + for (int i=1; i(values, out + axis_size - offset - N_READS); + } else { + write_safe(values, out + axis_size - offset - N_READS, offset, axis_size); + } + } else { + if (lid == 0 && offset == 0) { + out[axis_size-1] = Op::init; + } + if ((offset + N_READS + 1) < axis_size) { + write_unsafe(values, out + axis_size - offset - 1 - N_READS); + } else { + write_safe(values, out + axis_size - offset - 1 - N_READS, offset + 1, axis_size); + } + } + } else { + if (inclusive) { + if ((offset + N_READS) < axis_size) { + write_unsafe(values, out + offset); + } else { + write_safe(values, out + offset, offset, axis_size); + } + } else { + if (lid == 0 && offset == 0) { + out[0] = Op::init; + } + if ((offset + N_READS + 1) < axis_size) { + write_unsafe(values, out + offset + 1); + } else { + write_safe(values, out + offset + 1, offset + 1, axis_size); + } + } + } + + // Share the prefix + if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) { + simdgroup_sums[0] = values[N_READS-1]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + prefix = simdgroup_sums[0]; + } +} + +template +[[kernel]] void strided_scan( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const 
constant size_t & axis_size [[buffer(2)]], + const constant size_t & stride [[buffer(3)]], + uint2 gid [[threadgroup_position_in_grid]], + uint2 lid [[thread_position_in_threadgroup]], + uint2 lsize [[threads_per_threadgroup]], + uint simd_size [[threads_per_simdgroup]]) { + Op op; + + // Allocate memory + threadgroup U read_buffer[N_READS*32*32 + N_READS*32]; + U values[N_READS]; + U prefix[N_READS]; + for (int i=0; i, nreads, inclusive, reverse>( \ + const device itype* in [[buffer(0)]], \ + device otype* out [[buffer(1)]], \ + const constant size_t & axis_size [[buffer(2)]], \ + uint gid [[thread_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint lsize [[threads_per_threadgroup]], \ + uint simd_size [[threads_per_simdgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); + +#define instantiate_strided_scan(name, itype, otype, op, inclusive, reverse, nreads) \ + template [[host_name("strided_scan_" #name)]] \ + [[kernel]] void strided_scan, nreads, inclusive, reverse>( \ + const device itype* in [[buffer(0)]], \ + device otype* out [[buffer(1)]], \ + const constant size_t & axis_size [[buffer(2)]], \ + const constant size_t & stride [[buffer(3)]], \ + uint2 gid [[thread_position_in_grid]], \ + uint2 lid [[thread_position_in_threadgroup]], \ + uint2 lsize [[threads_per_threadgroup]], \ + uint simd_size [[threads_per_simdgroup]]); + + +#define instantiate_scan_helper(name, itype, otype, op, nreads) \ + instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \ + instantiate_contiguous_scan(exclusive_##name, itype, otype, op, false, false, nreads) \ + instantiate_contiguous_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \ + instantiate_contiguous_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) \ + instantiate_strided_scan(inclusive_##name, itype, otype, op, true, false, nreads) \ + instantiate_strided_scan(exclusive_##name, itype, otype, op, false, false, nreads) \ + instantiate_strided_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \ + instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) + +instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSum, 4) +instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4) +instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4) +instantiate_scan_helper(sum_uint32_uint32, uint32_t, uint32_t, CumSum, 4) +//instantiate_scan_helper(sum_uint64_uint64, uint64_t, uint64_t, CumSum, 2) +instantiate_scan_helper(sum_int8_int8, int8_t, int8_t, CumSum, 4) +instantiate_scan_helper(sum_int16_int16, int16_t, int16_t, CumSum, 4) +instantiate_scan_helper(sum_int32_int32, int32_t, int32_t, CumSum, 4) +//instantiate_scan_helper(sum_int64_int64, int64_t, int64_t, CumSum, 2) +instantiate_scan_helper(sum_float16_float16, half, half, CumSum, 4) +instantiate_scan_helper(sum_float32_float32, float, float, CumSum, 4) +//instantiate_scan_helper(sum_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumSum, 4) +//instantiate_scan_helper(sum_complex64_complex64, complex64_t, complex64_t, CumSum) +//instantiate_scan_helper(prod_bool__bool_, bool, bool, CumProd, 4) +instantiate_scan_helper(prod_uint8_uint8, uint8_t, uint8_t, CumProd, 4) +instantiate_scan_helper(prod_uint16_uint16, uint16_t, uint16_t, CumProd, 4) +instantiate_scan_helper(prod_uint32_uint32, uint32_t, uint32_t, CumProd, 4) +//instantiate_scan_helper(prod_uint64_uint64, 
uint64_t, uint64_t, CumProd, 2) +instantiate_scan_helper(prod_int8_int8, int8_t, int8_t, CumProd, 4) +instantiate_scan_helper(prod_int16_int16, int16_t, int16_t, CumProd, 4) +instantiate_scan_helper(prod_int32_int32, int32_t, int32_t, CumProd, 4) +//instantiate_scan_helper(prod_int64_int64, int64_t, int64_t, CumProd, 2) +instantiate_scan_helper(prod_float16_float16, half, half, CumProd, 4) +instantiate_scan_helper(prod_float32_float32, float, float, CumProd, 4) +//instantiate_scan_helper(prod_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumProd, 4) +//instantiate_scan_helper(prod_complex64_complex64, complex64_t, complex64_t, CumProd) +//instantiate_scan_helper(max_bool__bool_, bool, bool, CumMax, 4) +instantiate_scan_helper(max_uint8_uint8, uint8_t, uint8_t, CumMax, 4) +instantiate_scan_helper(max_uint16_uint16, uint16_t, uint16_t, CumMax, 4) +instantiate_scan_helper(max_uint32_uint32, uint32_t, uint32_t, CumMax, 4) +//instantiate_scan_helper(max_uint64_uint64, uint64_t, uint64_t, CumMax, 2) +instantiate_scan_helper(max_int8_int8, int8_t, int8_t, CumMax, 4) +instantiate_scan_helper(max_int16_int16, int16_t, int16_t, CumMax, 4) +instantiate_scan_helper(max_int32_int32, int32_t, int32_t, CumMax, 4) +//instantiate_scan_helper(max_int64_int64, int64_t, int64_t, CumMax, 2) +instantiate_scan_helper(max_float16_float16, half, half, CumMax, 4) +instantiate_scan_helper(max_float32_float32, float, float, CumMax, 4) +//instantiate_scan_helper(max_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMax, 4) +//instantiate_scan_helper(max_complex64_complex64, complex64_t, complex64_t, CumMax) +//instantiate_scan_helper(min_bool__bool_, bool, bool, CumMin, 4) +instantiate_scan_helper(min_uint8_uint8, uint8_t, uint8_t, CumMin, 4) +instantiate_scan_helper(min_uint16_uint16, uint16_t, uint16_t, CumMin, 4) +instantiate_scan_helper(min_uint32_uint32, uint32_t, uint32_t, CumMin, 4) +//instantiate_scan_helper(min_uint64_uint64, uint64_t, uint64_t, CumMin, 2) +instantiate_scan_helper(min_int8_int8, int8_t, int8_t, CumMin, 4) +instantiate_scan_helper(min_int16_int16, int16_t, int16_t, CumMin, 4) +instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMin, 4) +//instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMin, 2) +instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4) +instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4) +//instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4) +//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin) diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp new file mode 100644 index 000000000..d936ff3ed --- /dev/null +++ b/mlx/backend/metal/metal.cpp @@ -0,0 +1,88 @@ +#include +#include +#include + +#include "mlx/array.h" +#include "mlx/backend/metal/device.h" +#include "mlx/primitives.h" +#include "mlx/scheduler.h" + +namespace mlx::core::metal { + +int max_ops_per_buffer() { + auto get_val = []() { + if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) { + return atoi(buff_str); + } else { + return 10; + } + }; + static int max_ops_per_buffer_ = get_val(); + return max_ops_per_buffer_; +} + +#define MAX_OPS_PER_BUFFER max_ops_per_buffer() + +MTL::CommandBuffer* increment_command_buffer(Stream s) { + auto& d = metal::device(s.device); + auto command_buffer = d.get_command_buffer(s.index); + if (command_buffer == nullptr || + d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) { + if (command_buffer != nullptr) { + d.end_encoding(s.index); + 
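// The encoder is closed and this buffer has reached the op limit: notify the
// scheduler of the in-flight work, attach a completion handler, and commit it
// before a fresh command buffer is created below.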
scheduler::notify_new_task(s); + command_buffer->addCompletedHandler( + [s](MTL::CommandBuffer*) { scheduler::notify_task_completion(s); }); + d.commit_command_buffer(s.index); + } + command_buffer = d.new_command_buffer(s.index); + } + d.increment_command_buffer_ops(s.index); + return command_buffer; +} + +std::function make_task( + array& arr, + std::vector> deps, + std::shared_ptr> p, + bool retain_graph) { + auto task = + [retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable { + for (auto& d : deps) { + d.wait(); + } + auto s = arr.primitive().stream(); + auto command_buffer = increment_command_buffer(s); + arr.primitive().eval_gpu(arr.inputs(), arr); + if (p) { + metal::device(s.device).end_encoding(s.index); + scheduler::notify_new_task(s); + command_buffer->addCompletedHandler( + [retain_graph, s, arr, p = std::move(p)]( + MTL::CommandBuffer*) mutable { + if (!retain_graph) { + arr.detach(); + } + p->set_value(); + // Signal this thread to clear the pool on a synchroniztion. + scheduler::enqueue(s, []() { + thread_autorelease_pool()->release(); + thread_autorelease_pool() = + NS::AutoreleasePool::alloc()->init(); + }); + scheduler::notify_task_completion(s); + }); + metal::device(s.device).commit_command_buffer(s.index); + } else { + command_buffer->addCompletedHandler( + [retain_graph, s, arr](MTL::CommandBuffer*) mutable { + if (!retain_graph) { + arr.detach(); + } + }); + } + }; + return task; +} + +} // namespace mlx::core::metal diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h new file mode 100644 index 000000000..8195080f6 --- /dev/null +++ b/mlx/backend/metal/metal.h @@ -0,0 +1,28 @@ +#pragma once + +#include +#include +#include + +#include "mlx/array.h" +#include "mlx/stream.h" + +namespace mlx::core::metal { + +constexpr bool is_available() { +#ifdef _METAL_ + return true; +#else + return false; +#endif +} + +void new_stream(Stream stream); + +std::function make_task( + array& arr, + std::vector> deps, + std::shared_ptr> p, + bool retain_graph); + +} // namespace mlx::core::metal diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp new file mode 100644 index 000000000..7f7b51b92 --- /dev/null +++ b/mlx/backend/metal/softmax.cpp @@ -0,0 +1,82 @@ +#include + +#include "mlx/backend/metal/copy.h" +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +void Softmax::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + if (!is_floating_point(out.dtype())) { + throw std::runtime_error( + "[softmax] Does not support non-floating point types."); + } + auto& s = stream(); + auto& d = metal::device(s.device); + + // Make sure that the last dimension is contiguous + std::vector copies; + auto check_input = [&copies, &s](const array& x) { + if (x.strides()[x.ndim() - 1] == 1) { + return x; + } else { + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + copies.push_back(x_copy); + return x_copy; + } + }; + const array& in = check_input(inputs[0]); + out.set_data( + allocator::malloc_or_wait(in.data_size() * in.itemsize()), + in.data_size(), + in.strides(), + in.flags()); + + int axis_size = in.shape().back(); + int n_rows = in.data_size() / axis_size; + + const int simd_size = 32; + const int n_reads = SOFTMAX_N_READS; + const int looped_limit = SOFTMAX_LOOPED_LIMIT; + std::string op_name = "softmax_"; + if (axis_size > looped_limit) 
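+  // Rows longer than SOFTMAX_LOOPED_LIMIT dispatch the "looped_" kernel
+  // variant; shorter rows use the single-pass kernel sized below with
+  // SOFTMAX_N_READS elements per thread.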
{ + op_name += "looped_"; + } + op_name += type_to_name(out); + auto compute_encoder = d.get_command_encoder(s.index); + { + auto kernel = d.get_kernel(op_name); + + MTL::Size grid_dims, group_dims; + if (axis_size <= looped_limit) { + size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads; + size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size; + size_t threadgroup_size = simd_size * simds_needed; + assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup()); + size_t n_threads = n_rows * threadgroup_size; + grid_dims = MTL::Size(n_threads, 1, 1); + group_dims = MTL::Size(threadgroup_size, 1, 1); + } else { + size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup(); + size_t n_threads = n_rows * threadgroup_size; + grid_dims = MTL::Size(n_threads, 1, 1); + group_dims = MTL::Size(threadgroup_size, 1, 1); + } + + compute_encoder->setComputePipelineState(kernel); + set_array_buffer(compute_encoder, in, 0); + set_array_buffer(compute_encoder, out, 1); + compute_encoder->setBytes(&axis_size, sizeof(int), 2); + compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 0); + compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 1); + compute_encoder->dispatchThreads(grid_dims, group_dims); + } + d.get_command_buffer(s.index)->addCompletedHandler( + [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp new file mode 100644 index 000000000..8cffdd4eb --- /dev/null +++ b/mlx/backend/metal/sort.cpp @@ -0,0 +1,336 @@ +#include + +#include "mlx/backend/metal/copy.h" +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +namespace { + +template +void single_block_sort( + const Stream& s, + metal::Device& d, + const array& in, + array& out, + int axis, + int bn, + int tn) { + // Prepare shapes + int n_rows = in.size() / in.shape(axis); + + std::vector nc_str = in.strides(); + nc_str.erase(nc_str.begin() + axis); + + std::vector nc_shape = in.shape(); + nc_shape.erase(nc_shape.begin() + axis); + + int nc_dim = nc_shape.size(); + + int size_sorted_axis = in.shape(axis); + int stride_sorted_axis = in.strides()[axis]; + int stride_segment_axis = *std::min_element(nc_str.begin(), nc_str.end()); + + // Check if remaining strides are contiguous + bool contiguous_write = true; + if (axis != in.ndim() - 1 && axis != 0) { + for (int i = 0; i < nc_str.size() - 1; ++i) { + size_t expected = nc_str[i + 1] * nc_str[i + 1]; + contiguous_write &= (nc_str[i] == expected); + } + } + + // Prepare kernel name + std::ostringstream kname; + if (ARGSORT) { + kname << "arg_"; + } + kname << "block_merge_sort_" << type_to_name(in) << "_" << type_to_name(out) + << "_bn" << bn << "_tn" << tn; + + if (!contiguous_write) { + kname << "_nc"; + } + + // Prepare command encoder + auto compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + // Set inputs + set_array_buffer(compute_encoder, in, 0); + set_array_buffer(compute_encoder, out, 1); + compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2); + compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 3); + + if (contiguous_write) { + compute_encoder->setBytes(&stride_segment_axis, sizeof(int), 4); + } else { + compute_encoder->setBytes(&nc_dim, sizeof(int), 4); + compute_encoder->setBytes(nc_shape.data(), nc_dim * 
sizeof(int), 5); + compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 6); + } + + MTL::Size group_dims = MTL::Size(bn, 1, 1); + MTL::Size grid_dims = MTL::Size(1, n_rows, 1); + + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); +} + +template +void multi_block_sort( + const Stream& s, + metal::Device& d, + const array& in, + array& out, + int axis, + int bn, + int tn, + int n_blocks) { + // Prepare shapes + int n_rows = in.size() / in.shape(axis); + + std::vector nc_str = in.strides(); + nc_str.erase(nc_str.begin() + axis); + + std::vector nc_shape = in.shape(); + nc_shape.erase(nc_shape.begin() + axis); + + int nc_dim = nc_shape.size(); + + int size_sorted_axis = in.shape(axis); + int stride_sorted_axis = in.strides()[axis]; + + // Make temporary copies + array dev_vals_0({n_rows, size_sorted_axis}, in.dtype(), nullptr, {}); + array dev_vals_1({n_rows, size_sorted_axis}, in.dtype(), nullptr, {}); + + array dev_idxs_0({n_rows, size_sorted_axis}, uint32, nullptr, {}); + array dev_idxs_1({n_rows, size_sorted_axis}, uint32, nullptr, {}); + + array block_partitions({n_rows, n_blocks + 1}, uint32, nullptr, {}); + + // Do allocations + dev_vals_0.set_data(allocator::malloc_or_wait(dev_vals_0.nbytes())); + dev_vals_1.set_data(allocator::malloc_or_wait(dev_vals_1.nbytes())); + dev_idxs_0.set_data(allocator::malloc_or_wait(dev_idxs_0.nbytes())); + dev_idxs_1.set_data(allocator::malloc_or_wait(dev_idxs_1.nbytes())); + block_partitions.set_data( + allocator::malloc_or_wait(block_partitions.nbytes())); + + std::vector copies = { + dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions}; + + // Prepare command encoder + auto compute_encoder = d.get_command_encoder(s.index); + + // Do blockwise sort + { + std::ostringstream kname; + kname << "mb_block_sort_" << type_to_name(dev_vals_0) << "_" + << type_to_name(dev_idxs_0) << "_bn" << bn << "_tn" << tn; + + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + set_array_buffer(compute_encoder, in, 0); + set_array_buffer(compute_encoder, dev_vals_0, 1); + set_array_buffer(compute_encoder, dev_idxs_0, 2); + compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3); + compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 4); + compute_encoder->setBytes(&nc_dim, sizeof(int), 5); + compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6); + compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 7); + + MTL::Size group_dims = MTL::Size(bn, 1, 1); + MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1); + + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + } + + // Do merges + bool ping = false; + array dev_vals_in = dev_vals_0; + array dev_idxs_in = dev_idxs_0; + array dev_vals_out = dev_vals_1; + array dev_idxs_out = dev_idxs_1; + for (int merge_tiles = 2; merge_tiles <= n_blocks; merge_tiles *= 2) { + dev_vals_in = ping ? dev_vals_1 : dev_vals_0; + dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0; + dev_vals_out = ping ? dev_vals_0 : dev_vals_1; + dev_idxs_out = ping ? 
dev_idxs_0 : dev_idxs_1; + ping = !ping; + + // Do partiton + { + std::ostringstream kname; + kname << "mb_block_partiton_" << type_to_name(dev_vals_in) << "_" + << type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn; + + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + set_array_buffer(compute_encoder, block_partitions, 0); + set_array_buffer(compute_encoder, dev_vals_in, 1); + set_array_buffer(compute_encoder, dev_idxs_in, 2); + compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3); + compute_encoder->setBytes(&merge_tiles, sizeof(int), 4); + + MTL::Size group_dims = MTL::Size(n_blocks + 1, 1, 1); + MTL::Size grid_dims = MTL::Size(1, n_rows, 1); + + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + } + + // Do merge + { + std::ostringstream kname; + kname << "mb_block_merge_" << type_to_name(dev_vals_in) << "_" + << type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn; + + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + set_array_buffer(compute_encoder, block_partitions, 0); + set_array_buffer(compute_encoder, dev_vals_in, 1); + set_array_buffer(compute_encoder, dev_idxs_in, 2); + set_array_buffer(compute_encoder, dev_vals_out, 3); + set_array_buffer(compute_encoder, dev_idxs_out, 4); + compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 5); + compute_encoder->setBytes(&merge_tiles, sizeof(int), 6); + compute_encoder->setBytes(&n_blocks, sizeof(int), 7); + + MTL::Size group_dims = MTL::Size(bn, 1, 1); + MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1); + + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + } + } + + // Copy outputs with appropriate strides + array strided_out_arr = ARGSORT ? dev_idxs_out : dev_vals_out; + + if (axis == strided_out_arr.ndim() - 1) { + copy_gpu_inplace(strided_out_arr, out, CopyType::Vector, s); + } else { + std::vector strided_out_shape = strided_out_arr.shape(); + std::vector strided_out_str = strided_out_arr.strides(); + + int out_axis_shape = strided_out_shape[axis]; + int out_axis_str = strided_out_str[axis]; + + strided_out_shape.erase(strided_out_shape.begin() + axis); + strided_out_str.erase(strided_out_str.begin() + axis); + + strided_out_shape.push_back(out_axis_shape); + strided_out_str.push_back(out_axis_str); + + array strided_out_slice(strided_out_shape, out.dtype(), nullptr, {}); + strided_out_slice.copy_shared_buffer( + strided_out_arr, + strided_out_str, + strided_out_arr.flags(), + strided_out_arr.size(), + 0); + + copy_gpu_inplace(strided_out_slice, out, CopyType::General, s); + } + + // Clear copies + d.get_command_buffer(s.index)->addCompletedHandler( + [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); +} + +template +void gpu_merge_sort( + const Stream& s, + metal::Device& d, + const array& in, + array& out, + int axis_) { + // Get size info + int axis = axis_ < 0 ? 
axis_ + in.ndim() : axis_; + int size_sorted_axis = in.shape(axis); + + // Get kernel size + int tn = 8; + int bn = 128; + int potential_bn = (size_sorted_axis + tn - 1) / tn; + + if (potential_bn > 256) { + bn = 512; + } else if (potential_bn > 128) { + bn = 256; + } else { + bn = 128; + } + + if (bn == 512 && size_of(in.dtype()) > 4) { + bn = 256; + } + + int n_per_block = bn * tn; + int n_blocks = (size_sorted_axis + n_per_block - 1) / n_per_block; + + if (n_blocks > 1) { + return multi_block_sort(s, d, in, out, axis, bn, tn, n_blocks); + } else { + return single_block_sort(s, d, in, out, axis, bn, tn); + } +} + +} // namespace + +void ArgSort::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + auto& s = stream(); + auto& d = metal::device(s.device); + auto& in = inputs[0]; + + gpu_merge_sort(s, d, in, out, axis_); +} + +void Sort::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + auto& s = stream(); + auto& d = metal::device(s.device); + auto& in = inputs[0]; + + gpu_merge_sort(s, d, in, out, axis_); +} + +void ArgPartition::eval_gpu(const std::vector& inputs, array& out) { + // We direct arg partition to sort for now + assert(inputs.size() == 1); + + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + auto& s = stream(); + auto& d = metal::device(s.device); + auto& in = inputs[0]; + + gpu_merge_sort(s, d, in, out, axis_); +} + +void Partition::eval_gpu(const std::vector& inputs, array& out) { + // We direct partition to sort for now + assert(inputs.size() == 1); + + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + auto& s = stream(); + auto& d = metal::device(s.device); + auto& in = inputs[0]; + + gpu_merge_sort(s, d, in, out, axis_); +} + +} // namespace mlx::core diff --git a/mlx/backend/no_metal/CMakeLists.txt b/mlx/backend/no_metal/CMakeLists.txt new file mode 100644 index 000000000..6aaa766d6 --- /dev/null +++ b/mlx/backend/no_metal/CMakeLists.txt @@ -0,0 +1,7 @@ +target_sources( + mlx + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp +) diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp new file mode 100644 index 000000000..7a42e3ff6 --- /dev/null +++ b/mlx/backend/no_metal/primitives.cpp @@ -0,0 +1,77 @@ + +#include "mlx/primitives.h" + +#define NO_GPU(func) \ + void func::eval_gpu(const std::vector& inputs, array& out) { \ + throw std::runtime_error(#func " has no GPU implementation."); \ + } + +namespace mlx::core { + +NO_GPU(Abs) +NO_GPU(Add) +NO_GPU(Arange) +NO_GPU(ArcCos) +NO_GPU(ArcCosh) +NO_GPU(ArcSin) +NO_GPU(ArcSinh) +NO_GPU(ArcTan) +NO_GPU(ArcTanh) +NO_GPU(ArgPartition) +NO_GPU(ArgReduce) +NO_GPU(ArgSort) +NO_GPU(AsType) +NO_GPU(AsStrided) +NO_GPU(Broadcast) +NO_GPU(Concatenate) +NO_GPU(Convolution) +NO_GPU(Copy) +NO_GPU(Cos) +NO_GPU(Cosh) +NO_GPU(Divide) +NO_GPU(Equal) +NO_GPU(Erf) +NO_GPU(ErfInv) +NO_GPU(Exp) +NO_GPU(FFT) +NO_GPU(Full) +NO_GPU(Gather) +NO_GPU(Greater) +NO_GPU(GreaterEqual) +NO_GPU(Less) +NO_GPU(LessEqual) +NO_GPU(Load) +NO_GPU(Log) +NO_GPU(Log1p) +NO_GPU(LogicalNot) +NO_GPU(LogAddExp) +NO_GPU(Matmul) +NO_GPU(Maximum) +NO_GPU(Minimum) +NO_GPU(Multiply) +NO_GPU(Negative) +NO_GPU(NotEqual) +NO_GPU(Pad) +NO_GPU(Partition) +NO_GPU(Power) +NO_GPU(RandomBits) +NO_GPU(Reduce) +NO_GPU(Reshape) +NO_GPU(Scan) +NO_GPU(Scatter) 
+NO_GPU(Sigmoid) +NO_GPU(Sign) +NO_GPU(Sin) +NO_GPU(Sinh) +NO_GPU(Slice) +NO_GPU(Softmax) +NO_GPU(Sort) +NO_GPU(Square) +NO_GPU(Sqrt) +NO_GPU(StopGradient) +NO_GPU(Subtract) +NO_GPU(Tan) +NO_GPU(Tanh) +NO_GPU(Transpose) + +} // namespace mlx::core diff --git a/mlx/device.cpp b/mlx/device.cpp new file mode 100644 index 000000000..352b0284c --- /dev/null +++ b/mlx/device.cpp @@ -0,0 +1,29 @@ +#include "mlx/device.h" +#include "mlx/backend/metal/metal.h" + +namespace mlx::core { + +static Device default_device_{ + metal::is_available() ? Device::gpu : Device::cpu}; + +const Device& default_device() { + return default_device_; +} + +void set_default_device(const Device& d) { + if (!metal::is_available() && d == Device::gpu) { + throw std::invalid_argument( + "[set_default_device] Cannot set gpu device without gpu backend."); + } + default_device_ = d; +} + +bool operator==(const Device& lhs, const Device& rhs) { + return lhs.type == rhs.type && lhs.index == rhs.index; +} + +bool operator!=(const Device& lhs, const Device& rhs) { + return !(lhs == rhs); +} + +} // namespace mlx::core diff --git a/mlx/dtype.h b/mlx/dtype.h new file mode 100644 index 000000000..6a5f1179b --- /dev/null +++ b/mlx/dtype.h @@ -0,0 +1,99 @@ +#pragma once + +#include +#include +#include +#include + +#include "mlx/types/complex.h" +#include "mlx/types/half_types.h" + +namespace mlx::core { + +struct Dtype { + enum class Val { + bool_, + uint8, + uint16, + uint32, + uint64, + int8, + int16, + int32, + int64, + float16, + float32, + bfloat16, + complex64, + }; + + enum class Kind { + b, /* bool */ + u, /* unsigned int */ + i, /* signed int */ + f, /* float */ + c, /* complex */ + V, /* void - used for brain float */ + }; + + Val val; + const uint8_t size; + constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size){}; + constexpr operator Val() const { + return val; + }; +}; + +inline bool is_available(const Dtype& dtype) { + return true; +} + +static constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)}; + +static constexpr Dtype uint8{Dtype::Val::uint8, sizeof(uint8_t)}; +static constexpr Dtype uint16{Dtype::Val::uint16, sizeof(uint16_t)}; +static constexpr Dtype uint32{Dtype::Val::uint32, sizeof(uint32_t)}; +static constexpr Dtype uint64{Dtype::Val::uint64, sizeof(uint64_t)}; + +static constexpr Dtype int8{Dtype::Val::int8, sizeof(int8_t)}; +static constexpr Dtype int16{Dtype::Val::int16, sizeof(int16_t)}; +static constexpr Dtype int32{Dtype::Val::int32, sizeof(int32_t)}; +static constexpr Dtype int64{Dtype::Val::int64, sizeof(int64_t)}; + +static constexpr Dtype float16{Dtype::Val::float16, sizeof(uint16_t)}; +static constexpr Dtype float32{Dtype::Val::float32, sizeof(float)}; +static constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)}; +static constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)}; + +Dtype promote_types(const Dtype& t1, const Dtype& t2); + +inline uint8_t size_of(const Dtype& t) { + return t.size; +} + +Dtype::Kind kindof(const Dtype& t); + +inline bool is_unsigned(const Dtype& t) { + return kindof(t) == Dtype::Kind::u || kindof(t) == Dtype::Kind::b; +} + +inline bool is_floating_point(const Dtype& t) { + return kindof(t) == Dtype::Kind::f || kindof(t) == Dtype::Kind::V || + kindof(t) == Dtype::Kind::c; +} + +inline bool is_integral(const Dtype& t) { + return !(is_floating_point(t)); +} + +template +struct TypeToDtype { + operator Dtype(); +}; + +// Array protocol typestring for Dtype +std::string dtype_to_array_protocol(const Dtype& t); +// Dtype from array 
protocol type string +Dtype dtype_from_array_protocol(const std::string& t); + +} // namespace mlx::core diff --git a/mlx/fft.cpp b/mlx/fft.cpp new file mode 100644 index 000000000..96dc76bab --- /dev/null +++ b/mlx/fft.cpp @@ -0,0 +1,190 @@ +#include +#include + +#include "mlx/fft.h" +#include "mlx/ops.h" +#include "mlx/primitives.h" +#include "mlx/utils.h" + +namespace mlx::core::fft { + +array fft_impl( + const array& a, + std::vector n, + const std::vector& axes, + bool real, + bool inverse, + StreamOrDevice s) { + if (a.ndim() < 1) { + throw std::invalid_argument( + "[fftn] Requires array with at least one dimension."); + } + if (n.size() != axes.size()) { + throw std::invalid_argument("[fftn] Shape and axes have different sizes."); + } + if (axes.empty()) { + return a; + } + + std::vector valid_axes; + for (int ax : axes) { + valid_axes.push_back(ax < 0 ? ax + a.ndim() : ax); + } + std::set unique_axes(valid_axes.begin(), valid_axes.end()); + if (unique_axes.size() != axes.size()) { + std::ostringstream msg; + msg << "[fftn] Duplicated axis received " << axes; + throw std::invalid_argument(msg.str()); + } + if (*unique_axes.begin() < 0 || *unique_axes.rbegin() >= a.ndim()) { + std::ostringstream msg; + msg << "[fftn] Invalid axis received for array with " << a.ndim() + << " dimensions."; + throw std::invalid_argument(msg.str()); + } + + // In the following shape manipulations there are three cases to consdier: + // 1. In a complex to complex transform (fftn / ifftn) the output + // and input shapes are the same. + // 2. In a real to complex transform (rfftn) n specifies the input dims + // and the output dims are n[i] / 2 + 1 + // 3 In a complex to real transform (irfftn) n specifies the output dims + // and the input dims are n[i] / 2 + 1 + + if (std::any_of(n.begin(), n.end(), [](auto i) { return i <= 0; })) { + std::ostringstream msg; + msg << "[fftn] Invalid FFT output size requested " << n; + throw std::invalid_argument(msg.str()); + } + + std::vector in_shape = a.shape(); + for (int i = 0; i < valid_axes.size(); ++i) { + in_shape[valid_axes[i]] = n[i]; + } + if (real && inverse) { + in_shape[valid_axes.back()] = n.back() / 2 + 1; + } + + bool any_greater = false; + bool any_less = false; + for (int i = 0; i < in_shape.size(); ++i) { + any_greater |= in_shape[i] > a.shape()[i]; + any_less |= in_shape[i] < a.shape()[i]; + } + + auto in = a; + if (any_less) { + in = slice(in, std::vector(in.ndim(), 0), in_shape, s); + } + if (any_greater) { + // Pad with zeros + auto tmp = zeros(in_shape, a.dtype(), s); + in = scatter(tmp, std::vector{}, in, std::vector{}, s); + } + + auto out_shape = in_shape; + if (real) { + auto ax = valid_axes.back(); + out_shape[ax] = inverse ? n.back() : out_shape[ax] / 2 + 1; + } + + auto in_type = real && !inverse ? float32 : complex64; + auto out_type = real && inverse ? 
float32 : complex64; + return array( + out_shape, + out_type, + std::make_unique(to_stream(s), valid_axes, inverse, real), + {astype(in, in_type, s)}); +} + +array fft_impl( + const array& a, + const std::vector& axes, + bool real, + bool inverse, + StreamOrDevice s) { + std::vector n; + for (auto ax : axes) { + n.push_back(a.shape(ax)); + } + if (real && inverse) { + n.back() = (n.back() - 1) * 2; + } + return fft_impl(a, n, axes, real, inverse, s); +} + +array fft_impl(const array& a, bool real, bool inverse, StreamOrDevice s) { + std::vector axes(a.ndim()); + std::iota(axes.begin(), axes.end(), 0); + return fft_impl(a, axes, real, inverse, s); +} + +array fftn( + const array& a, + const std::vector& n, + const std::vector& axes, + StreamOrDevice s /* = {} */) { + return fft_impl(a, n, axes, false, false, s); +} +array fftn( + const array& a, + const std::vector& axes, + StreamOrDevice s /* = {} */) { + return fft_impl(a, axes, false, false, s); +} +array fftn(const array& a, StreamOrDevice s /* = {} */) { + return fft_impl(a, false, false, s); +} + +array ifftn( + const array& a, + const std::vector& n, + const std::vector& axes, + StreamOrDevice s /* = {} */) { + return fft_impl(a, n, axes, false, true, s); +} +array ifftn( + const array& a, + const std::vector& axes, + StreamOrDevice s /* = {} */) { + return fft_impl(a, axes, false, true, s); +} +array ifftn(const array& a, StreamOrDevice s /* = {} */) { + return fft_impl(a, false, true, s); +} + +array rfftn( + const array& a, + const std::vector& n, + const std::vector& axes, + StreamOrDevice s /* = {} */) { + return fft_impl(a, n, axes, true, false, s); +} +array rfftn( + const array& a, + const std::vector& axes, + StreamOrDevice s /* = {} */) { + return fft_impl(a, axes, true, false, s); +} +array rfftn(const array& a, StreamOrDevice s /* = {} */) { + return fft_impl(a, true, false, s); +} + +array irfftn( + const array& a, + const std::vector& n, + const std::vector& axes, + StreamOrDevice s /* = {} */) { + return fft_impl(a, n, axes, true, true, s); +} +array irfftn( + const array& a, + const std::vector& axes, + StreamOrDevice s /* = {} */) { + return fft_impl(a, axes, true, true, s); +} +array irfftn(const array& a, StreamOrDevice s /* = {} */) { + return fft_impl(a, true, true, s); +} + +} // namespace mlx::core::fft diff --git a/mlx/graph_utils.h b/mlx/graph_utils.h new file mode 100644 index 000000000..e0696e5ef --- /dev/null +++ b/mlx/graph_utils.h @@ -0,0 +1,21 @@ +#pragma once + +#include "mlx/array.h" + +namespace mlx::core { + +void print_graph(std::ostream& os, const std::vector& outputs); + +template +void print_graph(std::ostream& os, Arrays... outputs) { + print_graph(os, std::vector{std::forward(outputs)...}); +} + +void export_to_dot(std::ostream& os, const std::vector& outputs); + +template +void export_to_dot(std::ostream& os, Arrays... outputs) { + export_to_dot(os, std::vector{std::forward(outputs)...}); +} + +} // namespace mlx::core diff --git a/mlx/ops.cpp b/mlx/ops.cpp new file mode 100644 index 000000000..c851a44ac --- /dev/null +++ b/mlx/ops.cpp @@ -0,0 +1,2323 @@ +#include +#include +#include +#include + +#include "mlx/ops.h" +#include "mlx/primitives.h" +#include "mlx/utils.h" + +namespace mlx::core { + +namespace { + +std::pair, std::vector> compute_reduce_shape( + const std::vector& axes, + const std::vector& shape, + bool keepdims) { + std::set axes_set; + auto ndim = shape.size(); + for (auto ax : axes) { + int ax_ = (ax < 0) ? 
ax + ndim : ax; + if (ax_ < 0 || ax_ >= ndim) { + std::ostringstream msg; + msg << "Invalid axis " << ax << " for array with " << ndim + << " dimensions."; + throw std::out_of_range(msg.str()); + } + axes_set.insert(ax_); + } + if (axes_set.size() != axes.size()) { + throw std::invalid_argument("Duplicate axes detected in reduction."); + } + std::vector out_shape; + for (int i = 0; i < ndim; ++i) { + if (axes_set.count(i) == 0) { + out_shape.push_back(shape[i]); + } else if (keepdims) { + out_shape.push_back(1); + } + } + std::vector sorted_axes(axes_set.begin(), axes_set.end()); + return {out_shape, sorted_axes}; +} + +int compute_number_of_elements(const array& a, const std::vector& axes) { + int nelements = 1; + for (auto axis : axes) { + nelements *= a.shape(axis); + } + return nelements; +} + +Dtype at_least_float(const Dtype& d) { + return is_floating_point(d) ? d : promote_types(d, float32); +} + +} // namespace + +Stream to_stream(StreamOrDevice s) { + if (std::holds_alternative(s)) { + return default_stream(default_device()); + } else if (std::holds_alternative(s)) { + return default_stream(std::get(s)); + } else { + return std::get(s); + } +} + +array arange( + double start, + double stop, + double step, + Dtype dtype, + StreamOrDevice s /* = {} */) { + if (dtype == bool_) { + std::ostringstream msg; + msg << bool_ << " not supported for arange."; + throw std::invalid_argument(msg.str()); + } + int size = std::max(static_cast(std::ceil((stop - start) / step)), 0); + return array( + {size}, + dtype, + std::make_unique(to_stream(s), start, stop, step), + {}); +} +array arange( + double start, + double stop, + double step, + StreamOrDevice s /* = {} */) { + return arange(start, stop, step, float32, to_stream(s)); +} +array arange( + double start, + double stop, + Dtype dtype, + StreamOrDevice s /* = {} */) { + return arange(start, stop, 1.0, dtype, to_stream(s)); +} +array arange(double start, double stop, StreamOrDevice s /* = {} */) { + return arange(start, stop, 1.0, float32, to_stream(s)); +} +array arange(double stop, Dtype dtype, StreamOrDevice s /* = {} */) { + return arange(0.0, stop, 1.0, dtype, to_stream(s)); +} +array arange(double stop, StreamOrDevice s /* = {} */) { + return arange(0.0, stop, 1.0, float32, to_stream(s)); +} +array arange(int start, int stop, int step, StreamOrDevice s /* = {} */) { + return arange( + static_cast(start), + static_cast(stop), + static_cast(step), + int32, + to_stream(s)); +} +array arange(int start, int stop, StreamOrDevice s /* = {} */) { + return arange( + static_cast(start), + static_cast(stop), + 1.0, + int32, + to_stream(s)); +} +array arange(int stop, StreamOrDevice s /* = {} */) { + return arange(0.0, static_cast(stop), 1.0, int32, to_stream(s)); +} + +array astype(const array& a, Dtype dtype, StreamOrDevice s /* = {} */) { + if (dtype == a.dtype()) { + return a; + } + return array( + a.shape(), dtype, std::make_unique(to_stream(s), dtype), {a}); +} + +array as_strided( + const array& a, + std::vector shape, + std::vector strides, + size_t offset, + StreamOrDevice s /* = {} */) { + // Force the input array to be contiguous + auto x = reshape(a, {-1}, s); + return array( + shape, + a.dtype(), + std::make_unique(to_stream(s), shape, strides, offset), + {x}); +} + +array copy(const array& a, StreamOrDevice s /* = {} */) { + return array(a.shape(), a.dtype(), std::make_unique(to_stream(s)), {a}); +} + +array full( + const std::vector& shape, + const array& vals, + Dtype dtype, + StreamOrDevice s /* = {} */) { + auto in = 
broadcast_to(astype(vals, dtype, s), shape, s); + return array(shape, dtype, std::make_unique(to_stream(s)), {in}); +} + +array full( + const std::vector& shape, + const array& vals, + StreamOrDevice s /* = {} */) { + return full(shape, vals, vals.dtype(), to_stream(s)); +} + +array zeros( + const std::vector& shape, + Dtype dtype, + StreamOrDevice s /* = {} */) { + return full(shape, array(0, dtype), to_stream(s)); +} + +array zeros_like(const array& a, StreamOrDevice s /* = {} */) { + return zeros(a.shape(), a.dtype(), to_stream(s)); +} + +array ones( + const std::vector& shape, + Dtype dtype, + StreamOrDevice s /* = {} */) { + return full(shape, array(1, dtype), to_stream(s)); +} + +array ones_like(const array& a, StreamOrDevice s /* = {} */) { + return ones(a.shape(), a.dtype(), to_stream(s)); +} + +array reshape( + const array& a, + std::vector shape, + StreamOrDevice s /* = {} */) { + if (a.shape() == shape) { + return a; + } + + size_t size = 1; + int infer_idx = -1; + for (int i = 0; i < shape.size(); ++i) { + if (shape[i] == -1) { + if (infer_idx >= 0) { + throw std::invalid_argument("Reshape can only infer one dimension."); + } + infer_idx = i; + } else { + size *= shape[i]; + } + } + if (size > 0) { + auto q_and_r = std::ldiv(a.size(), size); + if (infer_idx >= 0) { + shape[infer_idx] = q_and_r.quot; + size *= q_and_r.quot; + } + } + if (a.size() != size) { + std::ostringstream msg; + msg << "Cannot reshape array of size " << a.size() << " into shape " + << shape << "."; + throw std::invalid_argument(msg.str()); + } + return array( + shape, a.dtype(), std::make_unique(to_stream(s), shape), {a}); +} + +array squeeze( + const array& a, + const std::vector& axes, + StreamOrDevice s /* = {} */) { + std::set unique_axes; + for (auto ax : axes) { + ax = ax < 0 ? ax + a.ndim() : ax; + if (ax < 0 || ax >= a.ndim()) { + std::ostringstream msg; + msg << "[squeeze] Invalid axies " << ax << " for array with " << a.ndim() + << " dimensions."; + throw std::invalid_argument(msg.str()); + } + if (a.shape(ax) != 1) { + std::ostringstream msg; + msg << "[squeeze] Cannot squeeze axis " << ax << " with size " + << a.shape(ax) << " which is not equal to 1."; + throw std::invalid_argument(msg.str()); + } + unique_axes.insert(ax); + } + + if (unique_axes.size() != axes.size()) { + throw std::invalid_argument("[squeeze] Received duplicate axes."); + } + std::vector sorted_axes(unique_axes.begin(), unique_axes.end()); + std::vector shape; + for (int i = 0, j = 0; i < a.ndim(); ++i) { + if (j < sorted_axes.size() && i == sorted_axes[j]) { + j++; + } else { + shape.push_back(a.shape(i)); + } + } + return reshape(a, shape, s); +} + +array squeeze(const array& a, StreamOrDevice s /* = {} */) { + std::vector axes; + for (int i = 0; i < a.ndim(); ++i) { + if (a.shape(i) == 1) { + axes.push_back(i); + } + } + return squeeze(a, axes, s); +} + +array expand_dims( + const array& a, + const std::vector& axes, + StreamOrDevice s /* = {} */) { + { // Check for repeats + std::set unique_axes(axes.begin(), axes.end()); + if (unique_axes.size() != axes.size()) { + throw std::invalid_argument("[expand_dims] Received duplicate axes."); + } + } + + int out_ndim = axes.size() + a.ndim(); + std::vector canonical_axes = axes; + for (auto& ax : canonical_axes) { + ax = ax < 0 ? 
ax + out_ndim : ax; + if (ax < 0 || ax >= out_ndim) { + std::ostringstream msg; + msg << "[squeeze] Invalid axies " << ax << " for output array with " + << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + } + + // Check for repeats again + std::set unique_axes(canonical_axes.begin(), canonical_axes.end()); + if (unique_axes.size() != axes.size()) { + throw std::invalid_argument("[expand_dims] Received duplicate axes."); + } + + std::vector sorted_axes(unique_axes.begin(), unique_axes.end()); + auto out_shape = a.shape(); + for (int i = 0; i < sorted_axes.size(); ++i) { + out_shape.insert(out_shape.begin() + sorted_axes[i], 1); + } + return reshape(a, out_shape, s); +} + +array slice( + const array& a, + std::vector start, + std::vector stop, + std::vector strides, + StreamOrDevice s /* = {} */) { + if (start.size() != a.ndim() || stop.size() != a.ndim() || + strides.size() != a.ndim()) { + std::ostringstream msg; + msg << "[slice] Invalid number of indices or strides for " + << "array with dimension " << a.ndim() << "."; + throw std::invalid_argument(msg.str()); + } + + std::vector negatively_strided_axes; + std::vector> negatively_strided_slices; + std::vector out_shape(a.ndim()); + for (int i = 0; i < a.ndim(); ++i) { + // Following numpy docs + // Negative i and j are interpreted as n + i and n + j where n is + // the number of elements in the corresponding dimension. Negative + // k makes stepping go towards smaller indices + + auto n = a.shape(i); + auto s = start[i]; + s = s < 0 ? s + n : s; + auto e = stop[i]; + e = e < 0 ? e + n : e; + + // Note: We pass positive strides to the primitive and then flip + // the axes later as needed + if (strides[i] < 0) { + negatively_strided_axes.push_back(i); + auto st = std::min(s, n - 1); + auto ed = std::max(e, -1); + negatively_strided_slices.push_back({st, ed, strides[i]}); + start[i] = 0; + stop[i] = n; + strides[i] = 1; + } else { + start[i] = s; + stop[i] = e < s ? 
s : e; + } + + // Clamp to bounds + start[i] = std::max(0, std::min(start[i], n)); + stop[i] = std::max(0, std::min(stop[i], n)); + + out_shape[i] = (stop[i] - start[i] + strides[i] - 1) / strides[i]; + } + + // If strides are negative, slice and then make a copy with axes flipped + if (negatively_strided_axes.size() > 0) { + // First, take the slice of the positvely strided axes + auto out = array( + out_shape, + a.dtype(), + std::make_unique( + to_stream(s), + std::move(start), + std::move(stop), + std::move(strides)), + {a}); + + std::vector indices; + std::vector slice_sizes = out.shape(); + std::vector t_axes(out.ndim(), -1); + std::vector out_reshape(out.ndim(), -1); + + int n_axes = negatively_strided_axes.size(); + for (int i = 0; i < n_axes; i++) { + // Get axis and corresponding slice + auto ax = negatively_strided_axes[i]; + auto sl = negatively_strided_slices[i]; + + // Get indices for the slice + auto ax_idx = arange(sl[0], sl[1], sl[2], s); + + // Reshape indices for broadcast as needed + std::vector ax_idx_shape(n_axes, 1); + ax_idx_shape[i] = ax_idx.size(); + ax_idx = reshape(ax_idx, ax_idx_shape, s); + + // Add indices to list + indices.push_back(ax_idx); + + // Set slice size for axis + slice_sizes[ax] = 1; + + // Gather moves the axis up, remainder needs to be squeezed + out_reshape[i] = indices[i].size(); + + // Gather moves the axis up, needs to be tranposed + t_axes[ax] = i; + } + + // Prepare out_reshape to squeeze gathered dims + // Prepare to transpose dims as needed + int j = n_axes; + for (int i = 0; j < out.ndim() && i < out.ndim(); i++) { + if (t_axes[i] < 0) { + t_axes[i] = j; + out_reshape[j] = out_shape[i]; + j++; + } + } + + // Gather + out = gather(out, indices, negatively_strided_axes, slice_sizes, s); + + // Squeeze dims + out = reshape(out, out_reshape, s); + + // Transpose dims + out = transpose(out, t_axes, s); + + return out; + } + if (out_shape == a.shape()) { + return a; + } + return array( + out_shape, + a.dtype(), + std::make_unique( + to_stream(s), std::move(start), std::move(stop), std::move(strides)), + {a}); +} + +array slice( + const array& a, + const std::vector& start, + const std::vector& stop, + StreamOrDevice s /* = {} */) { + return slice(a, start, stop, std::vector(a.ndim(), 1), to_stream(s)); +} + +std::vector split( + const array& a, + const std::vector& indices, + int axis, + StreamOrDevice s /* = {} */) { + auto ax = axis < 0 ? axis + a.ndim() : axis; + if (ax < 0 || ax >= a.ndim()) { + std::ostringstream msg; + msg << "Invalid axis (" << axis << ") passed to split" + << " for array with shape " << a.shape() << "."; + throw std::invalid_argument(msg.str()); + } + std::vector res; + auto out_shape = a.shape(); + auto start_indices = std::vector(a.ndim(), 0); + auto stop_indices = a.shape(); + for (int i = 0; i < indices.size() + 1; ++i) { + stop_indices[ax] = i < indices.size() ? 
indices[i] : a.shape(ax); + res.push_back(slice(a, start_indices, stop_indices, to_stream(s))); + start_indices[ax] = stop_indices[ax]; + } + return res; +} + +std::vector split( + const array& a, + const std::vector& indices, + StreamOrDevice s /* = {} */) { + return split(a, indices, 0, s); +} + +std::vector +split(const array& a, int num_splits, int axis, StreamOrDevice s /* = {} */) { + auto q_and_r = std::ldiv(a.shape(axis), num_splits); + if (q_and_r.rem) { + std::ostringstream msg; + msg << "Array split does not result in sub arrays with equal size:" + << " attempting " << num_splits << " splits along axis " << axis + << " for shape " << a.shape() << "."; + throw std::invalid_argument(msg.str()); + } + auto split_size = q_and_r.quot; + std::vector indices(num_splits - 1); + for (int i = 0; i < indices.size(); ++i) { + indices[i] = (i + 1) * split_size; + } + return split(a, indices, axis, s); +} + +std::vector +split(const array& a, int num_splits, StreamOrDevice s /* = {} */) { + return split(a, num_splits, 0, to_stream(s)); +} + +array concatenate( + const std::vector& arrays, + int axis, + StreamOrDevice s /* = {} */) { + if (arrays.size() == 0) { + throw std::invalid_argument("No arrays provided for concatenation"); + } + + // Normalize the given axis + auto ax = axis < 0 ? axis + arrays[0].ndim() : axis; + if (ax < 0 || ax >= arrays[0].ndim()) { + std::ostringstream msg; + msg << "Invalid axis (" << axis << ") passed to concatenate" + << " for array with shape " << arrays[0].shape() << "."; + throw std::invalid_argument(msg.str()); + } + + auto throw_invalid_shapes = [&]() { + std::ostringstream msg; + msg << "All the input array dimensions must match exactly except" + << " for the concatenation axis. However, the provided shapes are "; + for (auto& a : arrays) { + msg << a.shape() << ", "; + } + msg << "and the concatenation axis is " << axis; + throw std::invalid_argument(msg.str()); + }; + + std::vector shape = arrays[0].shape(); + shape[ax] = 0; + // Make the output shape and validate that all arrays have the same shape + // except for the concatenation axis. + for (auto& a : arrays) { + for (int i = 0; i < a.ndim(); i++) { + if (i == ax) { + continue; + } + if (a.shape(i) != shape[i]) { + throw_invalid_shapes(); + } + } + shape[ax] += a.shape(ax); + } + + return array( + shape, + arrays[0].dtype(), + std::make_unique(to_stream(s), ax), + arrays); +} + +array concatenate( + const std::vector& arrays, + StreamOrDevice s /* = {} */) { + std::vector flat_inputs; + for (auto& a : arrays) { + flat_inputs.push_back(reshape(a, {-1}, s)); + } + return concatenate(flat_inputs, 0, s); +} + +/** Pad an array with a constant value */ +array pad( + const array& a, + const std::vector& axes, + const std::vector& low_pad_size, + const std::vector& high_pad_size, + const array& pad_value /*= array(0)*/, + StreamOrDevice s /* = {}*/) { + if (axes.size() != low_pad_size.size() || + axes.size() != high_pad_size.size()) { + std::ostringstream msg; + msg << "Invalid number of padding sizes passed to pad " + << "with axes of size " << axes.size(); + throw std::invalid_argument(msg.str()); + } + + std::vector out_shape = a.shape(); + + for (int i = 0; i < axes.size(); i++) { + if (low_pad_size[i] < 0) { + std::ostringstream msg; + msg << "Invalid low padding size (" << low_pad_size[i] + << ") passed to pad" + << " for axis " << i << ". 
Padding sizes must be non-negative"; + throw std::invalid_argument(msg.str()); + } + if (high_pad_size[i] < 0) { + std::ostringstream msg; + msg << "Invalid high padding size (" << high_pad_size[i] + << ") passed to pad" + << " for axis " << i << ". Padding sizes must be non-negative"; + throw std::invalid_argument(msg.str()); + } + + auto ax = axes[i] < 0 ? a.ndim() + axes[i] : axes[i]; + out_shape[ax] += low_pad_size[i] + high_pad_size[i]; + } + + return array( + out_shape, + a.dtype(), + std::make_unique(to_stream(s), axes, low_pad_size, high_pad_size), + {a, astype(pad_value, a.dtype(), s)}); +} + +/** Pad an array with a constant value along all axes */ +array pad( + const array& a, + const std::vector>& pad_width, + const array& pad_value /*= array(0)*/, + StreamOrDevice s /*= {}*/) { + std::vector axes(a.ndim(), 0); + std::iota(axes.begin(), axes.end(), 0); + + std::vector lows; + std::vector highs; + + for (auto& pads : pad_width) { + lows.push_back(pads.first); + highs.push_back(pads.second); + } + + return pad(a, axes, lows, highs, pad_value, s); +} + +array pad( + const array& a, + const std::pair& pad_width, + const array& pad_value /*= array(0)*/, + StreamOrDevice s /*= {}*/) { + return pad( + a, std::vector>(a.ndim(), pad_width), pad_value, s); +} + +array pad( + const array& a, + int pad_width, + const array& pad_value /*= array(0)*/, + StreamOrDevice s /*= {}*/) { + return pad( + a, + std::vector>(a.ndim(), {pad_width, pad_width}), + pad_value, + s); +} + +array transpose( + const array& a, + std::vector axes, + StreamOrDevice s /* = {} */) { + for (auto& ax : axes) { + ax = ax < 0 ? ax + a.ndim() : ax; + } + std::set dims(axes.begin(), axes.end()); + if (dims.size() != axes.size()) { + throw std::invalid_argument("Repeat axes not allowed in transpose."); + } + if (dims.size() != a.ndim() || + a.ndim() > 0 && + (*dims.begin() != 0 || *dims.rbegin() != (a.ndim() - 1))) { + throw std::invalid_argument("Transpose axes don't match array dimensions."); + } + std::vector shape; + shape.reserve(axes.size()); + for (auto ax : axes) { + shape.push_back(a.shape()[ax]); + } + return array( + shape, + a.dtype(), + std::make_unique(to_stream(s), std::move(axes)), + {a}); +} + +array transpose(const array& a, StreamOrDevice s /* = {} */) { + std::vector axes(a.ndim()); + std::iota(axes.rbegin(), axes.rend(), 0); + return transpose(a, std::move(axes), to_stream(s)); +} + +array broadcast_to( + const array& a, + const std::vector& shape, + StreamOrDevice s /* = {} */) { + if (a.shape() == shape) { + return a; + } + + // Make sure the shapes are broadcastable + auto bxshape = broadcast_shapes(a.shape(), shape); + if (bxshape != shape) { + std::ostringstream msg; + msg << "Cannot broadcast array of shape " << a.shape() << " into shape " + << shape << "."; + throw std::invalid_argument(msg.str()); + } + return array( + shape, a.dtype(), std::make_unique(to_stream(s), shape), {a}); +} + +std::vector broadcast_arrays( + const std::vector& inputs, + StreamOrDevice s /* = {} */) { + std::vector shape{}; + for (const auto& in : inputs) { + shape = broadcast_shapes(shape, in.shape()); + } + std::vector outputs; + for (const auto& in : inputs) { + outputs.push_back(broadcast_to(in, shape, s)); + } + return outputs; +} + +array equal(const array& a, const array& b, StreamOrDevice s /* = {} */) { + auto dtype = promote_types(a.dtype(), b.dtype()); + std::vector inputs = {astype(a, dtype, s), astype(b, dtype, s)}; + if (a.shape() != b.shape()) { + inputs = broadcast_arrays(inputs, s); + } + return 
array( + inputs[0].shape(), bool_, std::make_unique(to_stream(s)), inputs); +} + +array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) { + auto dtype = promote_types(a.dtype(), b.dtype()); + std::vector inputs = {astype(a, dtype, s), astype(b, dtype, s)}; + if (a.shape() != b.shape()) { + inputs = broadcast_arrays(inputs, s); + } + return array( + inputs[0].shape(), + bool_, + std::make_unique(to_stream(s)), + inputs); +} + +array greater(const array& a, const array& b, StreamOrDevice s /* = {} */) { + auto dtype = promote_types(a.dtype(), b.dtype()); + std::vector inputs = {astype(a, dtype, s), astype(b, dtype, s)}; + if (a.shape() != b.shape()) { + inputs = broadcast_arrays(inputs, s); + } + return array( + inputs[0].shape(), + bool_, + std::make_unique(to_stream(s)), + inputs); +} + +array greater_equal( + const array& a, + const array& b, + StreamOrDevice s /* = {} */) { + auto dtype = promote_types(a.dtype(), b.dtype()); + std::vector inputs = {astype(a, dtype, s), astype(b, dtype, s)}; + if (a.shape() != b.shape()) { + inputs = broadcast_arrays(inputs, s); + } + return array( + inputs[0].shape(), + bool_, + std::make_unique(to_stream(s)), + inputs); +} + +array less(const array& a, const array& b, StreamOrDevice s /* = {} */) { + auto dtype = promote_types(a.dtype(), b.dtype()); + std::vector inputs = {astype(a, dtype, s), astype(b, dtype, s)}; + if (a.shape() != b.shape()) { + inputs = broadcast_arrays(inputs, s); + } + return array( + inputs[0].shape(), bool_, std::make_unique(to_stream(s)), inputs); +} + +array less_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) { + auto dtype = promote_types(a.dtype(), b.dtype()); + std::vector inputs = {astype(a, dtype, s), astype(b, dtype, s)}; + if (a.shape() != b.shape()) { + inputs = broadcast_arrays(inputs, s); + } + return array( + inputs[0].shape(), + bool_, + std::make_unique(to_stream(s)), + inputs); +} + +array array_equal( + const array& a, + const array& b, + bool equal_nan, + StreamOrDevice s /* = {} */) { + if (a.shape() != b.shape()) { + return array(false); + } else { + auto dtype = promote_types(a.dtype(), b.dtype()); + equal_nan &= is_floating_point(dtype); + return all( + array( + a.shape(), + bool_, + std::make_unique(to_stream(s), equal_nan), + {astype(a, dtype, s), astype(b, dtype, s)}), + false, + s); + } +} + +array where( + const array& condition, + const array& x, + const array& y, + StreamOrDevice s /* = {} */) { + // TODO, fix this to handle the NaN case when x has infs + auto mask = astype(condition, bool_, s); + return add(multiply(x, mask, s), multiply(y, logical_not(mask, s), s), s); +} + +array allclose( + const array& a, + const array& b, + double rtol /* = 1e-5 */, + double atol /* = 1e-8 */, + StreamOrDevice s /* = {}*/) { + // |a - b| <= atol + rtol * |b| + auto rhs = add(array(atol), multiply(array(rtol), abs(b, s), s), s); + auto lhs = abs(subtract(a, b, s), s); + return all(less_equal(lhs, rhs, s), s); +} + +array all(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) { + std::vector axes(a.ndim()); + std::iota(axes.begin(), axes.end(), 0); + return all(a, axes, keepdims, s); +} + +array all( + const array& a, + const std::vector& axes, + bool keepdims /* = false */, + StreamOrDevice s /* = {}*/) { + if (axes.empty()) { + return astype(a, bool_, s); + } + auto [out_shape, sorted_axes] = + compute_reduce_shape(axes, a.shape(), keepdims); + return array( + out_shape, + bool_, + std::make_unique(to_stream(s), Reduce::And, sorted_axes), + {a}); +} + +array 
all( + const array& a, + int axis, + bool keepdims /* = false */, + StreamOrDevice s /* = {} */) { + return all(a, std::vector{axis}, keepdims, s); +} + +array any(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) { + std::vector axes(a.ndim()); + std::iota(axes.begin(), axes.end(), 0); + return any(a, axes, keepdims, s); +} + +array any( + const array& a, + const std::vector& axes, + bool keepdims /* = false */, + StreamOrDevice s /* = {}*/) { + if (axes.empty()) { + return astype(a, bool_, s); + } + auto [out_shape, sorted_axes] = + compute_reduce_shape(axes, a.shape(), keepdims); + return array( + out_shape, + bool_, + std::make_unique(to_stream(s), Reduce::Or, sorted_axes), + {a}); +} + +array any( + const array& a, + int axis, + bool keepdims /* = false */, + StreamOrDevice s /* = {} */) { + return any(a, std::vector{axis}, keepdims, s); +} + +array sum(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) { + std::vector axes(a.ndim()); + std::iota(axes.begin(), axes.end(), 0); + return sum(a, axes, keepdims, s); +} + +array sum( + const array& a, + const std::vector& axes, + bool keepdims /* = false */, + StreamOrDevice s /* = {}*/) { + if (axes.empty()) { + return a; + } + auto [out_shape, sorted_axes] = + compute_reduce_shape(axes, a.shape(), keepdims); + auto out_type = a.dtype() == bool_ ? int32 : a.dtype(); + return array( + out_shape, + out_type, + std::make_unique(to_stream(s), Reduce::Sum, sorted_axes), + {a}); +} + +array sum( + const array& a, + int axis, + bool keepdims /* = false */, + StreamOrDevice s /* = {} */) { + return sum(a, std::vector{axis}, keepdims, s); +} + +array mean(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) { + std::vector axes(a.ndim()); + std::iota(axes.begin(), axes.end(), 0); + return mean(a, axes, keepdims, to_stream(s)); +} + +array mean( + const array& a, + const std::vector& axes, + bool keepdims /* = false */, + StreamOrDevice s /* = {}*/) { + auto nelements = compute_number_of_elements(a, axes); + auto dtype = at_least_float(a.dtype()); + return multiply(sum(a, axes, keepdims, s), array(1.0 / nelements, dtype), s); +} + +array mean( + const array& a, + int axis, + bool keepdims /* = false */, + StreamOrDevice s /* = {} */) { + return mean(a, std::vector{axis}, keepdims, to_stream(s)); +} + +array var( + const array& a, + bool keepdims, + int ddof /* = 0*/, + StreamOrDevice s /* = {}*/) { + std::vector axes(a.ndim()); + std::iota(axes.begin(), axes.end(), 0); + return var(a, axes, keepdims, ddof, to_stream(s)); +} + +array var( + const array& a, + const std::vector& axes, + bool keepdims /* = false */, + int ddof /* = 0*/, + StreamOrDevice s /* = {}*/) { + auto nelements = compute_number_of_elements(a, axes); + auto dtype = at_least_float(a.dtype()); + auto mu = mean(a, axes, true, s); + auto S = sum(square(subtract(a, mu, s), s), axes, keepdims, s); + return multiply(S, array(1.0 / (nelements - ddof), dtype), s); +} + +array var( + const array& a, + int axis, + bool keepdims /* = false */, + int ddof /* = 0*/, + StreamOrDevice s /* = {} */) { + return var(a, std::vector{axis}, keepdims, ddof, to_stream(s)); +} + +array prod(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) { + std::vector axes(a.ndim()); + std::iota(axes.begin(), axes.end(), 0); + return prod(a, axes, keepdims, s); +} + +array prod( + const array& a, + const std::vector& axes, + bool keepdims /* = false */, + StreamOrDevice s /* = {}*/) { + if (axes.empty()) { + return a; + } + auto [out_shape, sorted_axes] = + compute_reduce_shape(axes, 
a.shape(), keepdims); + return array( + out_shape, + a.dtype(), + std::make_unique(to_stream(s), Reduce::Prod, sorted_axes), + {a}); +} + +array prod( + const array& a, + int axis, + bool keepdims /* = false */, + StreamOrDevice s /* = {} */) { + return prod(a, std::vector{axis}, keepdims, s); +} + +array max(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) { + std::vector axes(a.ndim()); + std::iota(axes.begin(), axes.end(), 0); + return max(a, axes, keepdims, s); +} + +array max( + const array& a, + const std::vector& axes, + bool keepdims /* = false */, + StreamOrDevice s /* = {}*/) { + if (a.size() == 0) { + throw std::invalid_argument("[max] Cannot max reduce zero size array."); + } + if (axes.empty()) { + return a; + } + auto [out_shape, sorted_axes] = + compute_reduce_shape(axes, a.shape(), keepdims); + return array( + out_shape, + a.dtype(), + std::make_unique(to_stream(s), Reduce::Max, sorted_axes), + {a}); +} + +array max( + const array& a, + int axis, + bool keepdims /* = false */, + StreamOrDevice s /* = {} */) { + return max(a, std::vector{axis}, keepdims, s); +} + +array min(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) { + std::vector axes(a.ndim()); + std::iota(axes.begin(), axes.end(), 0); + return min(a, axes, keepdims, s); +} + +array min( + const array& a, + const std::vector& axes, + bool keepdims /* = false */, + StreamOrDevice s /* = {}*/) { + if (a.size() == 0) { + throw std::invalid_argument("[min] Cannot min reduce zero size array."); + } + if (axes.empty()) { + return a; + } + auto [out_shape, sorted_axes] = + compute_reduce_shape(axes, a.shape(), keepdims); + return array( + out_shape, + a.dtype(), + std::make_unique(to_stream(s), Reduce::Min, sorted_axes), + {a}); +} + +array min( + const array& a, + int axis, + bool keepdims /* = false */, + StreamOrDevice s /* = {} */) { + return min(a, std::vector{axis}, keepdims, s); +} + +array argmin(const array& a, bool keepdims, StreamOrDevice s /* = {} */) { + int size = a.size(); + auto result = argmin(reshape(a, {size}, s), 0, false, s); + if (keepdims) { + result = reshape(result, std::vector(a.shape().size(), 1), s); + } + return result; +} + +array argmin( + const array& a, + int axis, + bool keepdims /* = false */, + StreamOrDevice s /* = {} */) { + if (a.size() == 0) { + throw std::invalid_argument( + "[argmin] Cannot argmin reduce zero size array."); + } + auto [out_shape, sorted_axes] = + compute_reduce_shape({axis}, a.shape(), keepdims); + return array( + out_shape, + uint32, + std::make_unique( + to_stream(s), ArgReduce::ArgMin, sorted_axes[0]), + {a}); +} + +array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) { + int size = a.size(); + auto result = argmax(reshape(a, {size}, s), 0, false, s); + if (keepdims) { + result = reshape(result, std::vector(a.shape().size(), 1), s); + } + return result; +} + +array argmax( + const array& a, + int axis, + bool keepdims /* = false */, + StreamOrDevice s /* = {} */) { + if (a.size() == 0) { + throw std::invalid_argument( + "[argmax] Cannot argmax reduce zero size array."); + } + auto [out_shape, sorted_axes] = + compute_reduce_shape({axis}, a.shape(), keepdims); + return array( + out_shape, + uint32, + std::make_unique( + to_stream(s), ArgReduce::ArgMax, sorted_axes[0]), + {a}); +} + +/** Returns a sorted copy of the flattened array. */ +array sort(const array& a, StreamOrDevice s /* = {} */) { + int size = a.size(); + return sort(reshape(a, {size}, s), 0, s); +} + +/** Returns a sorted copy of the array along a given axis. 
*/ +array sort(const array& a, int axis, StreamOrDevice s /* = {} */) { + // Check for valid axis + if (axis + static_cast(a.ndim()) < 0 || + axis >= static_cast(a.ndim())) { + std::ostringstream msg; + msg << "[sort] Received invalid axis " << axis << " for array with " + << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + + // TODO: Fix GPU kernel + if (a.shape(axis) >= (1u << 21) && to_stream(s).device.type == Device::gpu) { + std::ostringstream msg; + msg << "[sort] GPU sort cannot handle sort axis of >= 2M elements," + << " got array with sort axis size " << a.shape(axis) << "." + << " Please place this operation on the CPU instead."; + throw std::runtime_error(msg.str()); + } + + return array( + a.shape(), a.dtype(), std::make_unique(to_stream(s), axis), {a}); +} + +/** Returns indices that sort the flattened array. */ +array argsort(const array& a, StreamOrDevice s /* = {} */) { + int size = a.size(); + return argsort(reshape(a, {size}, s), 0, s); +} + +/** Returns indices that sort the array along a given axis. */ +array argsort(const array& a, int axis, StreamOrDevice s /* = {} */) { + // Check for valid axis + if (axis + static_cast(a.ndim()) < 0 || + axis >= static_cast(a.ndim())) { + std::ostringstream msg; + msg << "[argsort] Received invalid axis " << axis << " for array with " + << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + + // TODO: Fix GPU kernel + if (a.shape(axis) >= (1u << 21) && to_stream(s).device.type == Device::gpu) { + std::ostringstream msg; + msg << "[argsort] GPU sort cannot handle sort axis of >= 2M elements," + << " got array with sort axis size " << a.shape(axis) << "." + << " Please place this operation on the CPU instead."; + throw std::runtime_error(msg.str()); + } + + return array( + a.shape(), uint32, std::make_unique(to_stream(s), axis), {a}); +} + +/** + * Returns a partitioned copy of the flattened array + * such that the smaller kth elements are first. + **/ +array partition(const array& a, int kth, StreamOrDevice s /* = {} */) { + int size = a.size(); + return partition(reshape(a, {size}, s), kth, 0, s); +} + +/** + * Returns a partitioned copy of the array along a given axis + * such that the smaller kth elements are first. + **/ +array partition( + const array& a, + int kth, + int axis, + StreamOrDevice s /* = {} */) { + // Check for valid axis + if (axis + static_cast(a.ndim()) < 0 || + axis >= static_cast(a.ndim())) { + std::ostringstream msg; + msg << "[partition] Received invalid axis " << axis << " for array with " + << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + int axis_ = axis < 0 ? axis + a.ndim() : axis; + int kth_ = kth < 0 ? kth + a.shape(axis) : kth; + if (kth_ < 0 || kth_ >= a.shape(axis_)) { + std::ostringstream msg; + msg << "[partition] Received invalid kth " << kth << "along axis " << axis + << " for array with shape: " << a.shape(); + throw std::invalid_argument(msg.str()); + } + return array( + a.shape(), + a.dtype(), + std::make_unique(to_stream(s), kth_, axis_), + {a}); +} + +/** + * Returns indices that partition the flattened array + * such that the smaller kth elements are first. + **/ +array argpartition(const array& a, int kth, StreamOrDevice s /* = {} */) { + int size = a.size(); + return argpartition(reshape(a, {size}, s), kth, 0, s); +} + +/** + * Returns indices that partition the array along a given axis + * such that the smaller kth elements are first. 
+ **/ +array argpartition( + const array& a, + int kth, + int axis, + StreamOrDevice s /* = {} */) { + // Check for valid axis + if (axis + static_cast(a.ndim()) < 0 || + axis >= static_cast(a.ndim())) { + std::ostringstream msg; + msg << "[argpartition] Received invalid axis " << axis << " for array with " + << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + int axis_ = axis < 0 ? axis + a.ndim() : axis; + int kth_ = kth < 0 ? kth + a.shape(axis) : kth; + if (kth_ < 0 || kth_ >= a.shape(axis_)) { + std::ostringstream msg; + msg << "[argpartition] Received invalid kth " << kth << "along axis " + << axis << " for array with shape: " << a.shape(); + throw std::invalid_argument(msg.str()); + } + return array( + a.shape(), + uint32, + std::make_unique(to_stream(s), kth_, axis_), + {a}); +} + +/** Returns topk elements of the flattened array. */ +array topk(const array& a, int k, StreamOrDevice s /* = {}*/) { + int size = a.size(); + return topk(reshape(a, {size}, s), k, 0, s); +} + +/** Returns topk elements of the array along a given axis. */ +array topk(const array& a, int k, int axis, StreamOrDevice s /* = {}*/) { + // Check for valid axis + int axis_ = axis < 0 ? axis + a.ndim() : axis; + int kth_ = k < 0 ? k + a.shape(axis) : k; + if (axis_ < 0 || axis_ >= static_cast(a.ndim())) { + std::ostringstream msg; + msg << "[topk] Received invalid axis " << axis << " for array with " + << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + if (kth_ < 0 || kth_ >= a.shape(axis_)) { + std::ostringstream msg; + msg << "[topk] Received invalid k " << k << "along axis " << axis + << " for array with shape: " << a.shape(); + throw std::invalid_argument(msg.str()); + } + + array a_partitioned = partition(a, kth_, axis_, s); + std::vector slice_starts(a.ndim(), 0); + std::vector slice_ends = a.shape(); + slice_starts[axis_] = kth_; + return slice(a_partitioned, slice_starts, slice_ends, s); +} + +array logsumexp(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) { + std::vector axes(a.ndim()); + std::iota(axes.begin(), axes.end(), 0); + return logsumexp(a, axes, keepdims, s); +} + +array logsumexp( + const array& a, + const std::vector& axes, + bool keepdims /* = false */, + StreamOrDevice s /* = {}*/) { + auto maxval = stop_gradient(max(a, axes, true, s)); + auto out = log(sum(exp(subtract(a, maxval, s), s), axes, keepdims, s), s); + return add(out, reshape(maxval, out.shape(), s), s); +} + +array logsumexp( + const array& a, + int axis, + bool keepdims /* = false */, + StreamOrDevice s /* = {} */) { + return logsumexp(a, std::vector{axis}, keepdims, s); +} + +array abs(const array& a, StreamOrDevice s /* = {} */) { + auto out = + array(a.shape(), a.dtype(), std::make_unique(to_stream(s)), {a}); + if (a.dtype() == complex64) { + out = astype(out, float32, s); + } + return out; +} + +array negative(const array& a, StreamOrDevice s /* = {} */) { + if (a.dtype() == bool_) { + auto msg = "[negative] Not supported for bool, use logical_not instead."; + throw std::invalid_argument(msg); + } + return array( + a.shape(), a.dtype(), std::make_unique(to_stream(s)), {a}); +} +array operator-(const array& a) { + return negative(a); +} + +array sign(const array& a, StreamOrDevice s /* = {} */) { + return array(a.shape(), a.dtype(), std::make_unique(to_stream(s)), {a}); +} + +array logical_not(const array& a, StreamOrDevice s /* = {} */) { + return array( + a.shape(), + bool_, + std::make_unique(to_stream(s)), + {astype(a, bool_, s)}); +} + +array 
reciprocal(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + return divide(array(1.0f, dtype), a, to_stream(s)); +} + +array add(const array& a, const array& b, StreamOrDevice s /* = {} */) { + auto out_type = promote_types(a.dtype(), b.dtype()); + auto inputs = + broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s); + return array( + inputs[0].shape(), out_type, std::make_unique(to_stream(s)), inputs); +} + +array operator+(const array& a, const array& b) { + return add(a, b); +} + +array subtract(const array& a, const array& b, StreamOrDevice s /* = {} */) { + auto out_type = promote_types(a.dtype(), b.dtype()); + auto inputs = + broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s); + return array( + inputs[0].shape(), + out_type, + std::make_unique(to_stream(s)), + inputs); +} + +array operator-(const array& a, const array& b) { + return subtract(a, b); +} + +array multiply(const array& a, const array& b, StreamOrDevice s /* = {} */) { + auto out_type = promote_types(a.dtype(), b.dtype()); + auto inputs = + broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s); + return array( + inputs[0].shape(), + out_type, + std::make_unique(to_stream(s)), + inputs); +} + +array operator*(const array& a, const array& b) { + return multiply(a, b); +} + +array divide(const array& a, const array& b, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(promote_types(a.dtype(), b.dtype())); + auto inputs = broadcast_arrays( + {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s); + return array( + inputs[0].shape(), dtype, std::make_unique(to_stream(s)), inputs); +} +array operator/(const array& a, const array& b) { + return divide(a, b); +} +array operator/(double a, const array& b) { + return divide(array(a), b); +} +array operator/(const array& a, double b) { + return divide(a, array(b)); +} + +array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) { + auto out_type = promote_types(a.dtype(), b.dtype()); + auto inputs = + broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s); + return array( + inputs[0].shape(), + out_type, + std::make_unique(to_stream(s)), + inputs); +} + +array minimum(const array& a, const array& b, StreamOrDevice s /* = {} */) { + auto out_type = promote_types(a.dtype(), b.dtype()); + auto inputs = + broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s); + return array( + inputs[0].shape(), + out_type, + std::make_unique(to_stream(s)), + inputs); +} + +array square(const array& a, StreamOrDevice s /* = {} */) { + return array( + a.shape(), a.dtype(), std::make_unique(to_stream(s)), {a}); +} + +array exp(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + auto input = astype(a, dtype, s); + return array(a.shape(), dtype, std::make_unique(to_stream(s)), {input}); +} + +array sin(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + auto input = astype(a, dtype, s); + return array(a.shape(), dtype, std::make_unique(to_stream(s)), {input}); +} + +array cos(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + auto input = astype(a, dtype, s); + return array(a.shape(), dtype, std::make_unique(to_stream(s)), {input}); +} + +array tan(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + auto input = astype(a, dtype, s); + return array(a.shape(), dtype, std::make_unique(to_stream(s)), 
{input}); +} + +array arcsin(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + auto input = astype(a, dtype, s); + return array( + a.shape(), dtype, std::make_unique(to_stream(s)), {input}); +} + +array arccos(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + auto input = astype(a, dtype, s); + return array( + a.shape(), dtype, std::make_unique(to_stream(s)), {input}); +} + +array arctan(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + auto input = astype(a, dtype, s); + return array( + a.shape(), dtype, std::make_unique(to_stream(s)), {input}); +} + +array sinh(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + auto input = astype(a, dtype, s); + return array(a.shape(), dtype, std::make_unique(to_stream(s)), {input}); +} + +array cosh(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + auto input = astype(a, dtype, s); + return array(a.shape(), dtype, std::make_unique(to_stream(s)), {input}); +} + +array tanh(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + auto input = astype(a, dtype, s); + return array(a.shape(), dtype, std::make_unique(to_stream(s)), {input}); +} + +array arcsinh(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + auto input = astype(a, dtype, s); + return array( + a.shape(), dtype, std::make_unique(to_stream(s)), {input}); +} + +array arccosh(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + auto input = astype(a, dtype, s); + return array( + a.shape(), dtype, std::make_unique(to_stream(s)), {input}); +} + +array arctanh(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + auto input = astype(a, dtype, s); + return array( + a.shape(), dtype, std::make_unique(to_stream(s)), {input}); +} + +array log(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + auto input = astype(a, dtype, s); + return array( + a.shape(), + dtype, + std::make_unique(to_stream(s), Log::Base::e), + {input}); +} + +array log2(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + auto input = astype(a, dtype, s); + return array( + a.shape(), + dtype, + std::make_unique(to_stream(s), Log::Base::two), + {input}); +} + +array log10(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + auto input = astype(a, dtype, s); + return array( + a.shape(), + dtype, + std::make_unique(to_stream(s), Log::Base::ten), + {input}); +} + +array log1p(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + auto input = astype(a, dtype, s); + return array( + a.shape(), dtype, std::make_unique(to_stream(s)), {input}); +} + +array logaddexp(const array& a, const array& b, StreamOrDevice s /* = {} */) { + // Make sure out type is floating point + auto out_type = at_least_float(promote_types(a.dtype(), b.dtype())); + auto inputs = + broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s); + return array( + inputs[0].shape(), + out_type, + std::make_unique(to_stream(s)), + inputs); +} + +array sigmoid(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + auto input = astype(a, dtype, s); + return array( + a.shape(), dtype, std::make_unique(to_stream(s)), {input}); +} + 
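The unary ops above all follow one pattern: the input is promoted with `at_least_float` and cast before the primitive is constructed, so integer inputs produce floating-point outputs. `logaddexp` is the pairwise counterpart of the `logsumexp` reduction defined earlier; as a quick illustration of the max-shift identity that `logsumexp` uses (plain Python, no mlx dependency assumed):

```
import math

# log(exp(a) + exp(b)) without overflowing exp(), using the same
# max-shift that the logsumexp reduction above applies along its axes
def logaddexp(a: float, b: float) -> float:
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))

assert abs(logaddexp(0.0, 0.0) - math.log(2.0)) < 1e-12
# a naive math.exp(1000.0) raises OverflowError, the shifted form does not
assert abs(logaddexp(1000.0, 1000.0) - (1000.0 + math.log(2.0))) < 1e-9
```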
+array erf(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + return array( + a.shape(), + dtype, + std::make_unique(to_stream(s)), + {astype(a, dtype, s)}); +} + +array erfinv(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + return array( + a.shape(), + dtype, + std::make_unique(to_stream(s)), + {astype(a, dtype, s)}); +} + +array stop_gradient(const array& a, StreamOrDevice s /* = {} */) { + return array( + a.shape(), a.dtype(), std::make_unique(to_stream(s)), {a}); +} + +array matmul( + const array& in_a, + const array& in_b, + StreamOrDevice s /* = {} */) { + auto a = in_a; + auto b = in_b; + if (a.ndim() == 0 || b.ndim() == 0) { + throw std::invalid_argument( + "[matmul] Got 0 dimension input. Inputs must " + "have at least one dimension."); + } + if (a.ndim() == 1) { + // Insert a singleton dim in the beginning + a = reshape(a, {1, -1}, s); + } + if (b.ndim() == 1) { + // Insert a singleton dim at the end + b = reshape(b, {-1, 1}, s); + } + if (a.shape(-1) != b.shape(-2)) { + std::ostringstream msg; + msg << "[matmul] Last dimension of first input with shape " << a.shape() + << " must match second to last dimension of" + << " second input with shape " << b.shape() << "."; + throw std::invalid_argument(msg.str()); + } + // Type promotion + auto out_type = promote_types(a.dtype(), b.dtype()); + if (a.dtype() != out_type) { + a = astype(a, out_type, s); + } + if (b.dtype() != out_type) { + b = astype(b, out_type, s); + } + + // We can batch the multiplication by reshaping a + if (a.ndim() > 2 && b.ndim() == 2) { + std::vector out_shape = a.shape(); + a = reshape(a, {-1, out_shape.back()}, s); + out_shape.back() = b.shape(-1); + if (in_b.ndim() == 1) { + out_shape.pop_back(); + } + auto out = array( + {a.shape(0), b.shape(1)}, + out_type, + std::make_unique(to_stream(s)), + {a, b}); + return reshape(out, out_shape, s); + } + + if (a.ndim() > 2 || b.ndim() > 2) { + std::vector bsx_a(a.shape().begin(), a.shape().end() - 2); + std::vector bsx_b(b.shape().begin(), b.shape().end() - 2); + auto inner_shape = broadcast_shapes(bsx_a, bsx_b); + + // Broadcast a + inner_shape.push_back(a.shape(-2)); + inner_shape.push_back(a.shape(-1)); + a = broadcast_to(a, inner_shape, s); + + // Broadcast b + *(inner_shape.end() - 2) = b.shape(-2); + *(inner_shape.end() - 1) = b.shape(-1); + b = broadcast_to(b, inner_shape, s); + } + + auto out_shape = a.shape(); + out_shape.back() = b.shape(-1); + + auto out = array( + out_shape, out_type, std::make_unique(to_stream(s)), {a, b}); + + // Remove the possibly inserted singleton dimensions + if (in_a.ndim() == 1 || in_b.ndim() == 1) { + out_shape.erase( + out_shape.end() - ((in_a.ndim() == 1) ? 2 : 1), + out_shape.end() - ((in_b.ndim() == 1) ? 0 : 1)); + out = reshape(out, out_shape, s); + } + return out; +} + +array gather( + const array& a, + const std::vector& indices, + const std::vector& axes, + const std::vector& slice_sizes, + StreamOrDevice s /* = {} */) { + // Checks that indices, dimensions, and slice_sizes are all valid + if (indices.size() > a.ndim()) { + std::ostringstream msg; + msg << "[gather] Too many index arrays. 
Got " << indices.size() + << " index arrays for input with " << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + + std::set dims(axes.begin(), axes.end()); + if (dims.size() != axes.size()) { + throw std::invalid_argument("[gather] Repeat axes not allowed in gather."); + } + if (!dims.empty() && (*dims.begin() < 0 || *dims.rbegin() >= a.ndim())) { + throw std::invalid_argument("[gather] Axes don't match array dimensions."); + } + if (indices.size() != axes.size()) { + throw std::invalid_argument( + "[gather] Number of index arrays does not match number of axes."); + } + for (auto& x : indices) { + if (x.dtype() == bool_) { + throw("[Gather] Boolean indices not supported."); + } + } + + if (slice_sizes.size() != a.ndim()) { + std::ostringstream msg; + msg << "[gather] Got slice_sizes with size " << slice_sizes.size() + << " for array with " << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + for (int i = 0; i < a.ndim(); ++i) { + if (slice_sizes[i] < 0 || slice_sizes[i] > a.shape(i)) { + std::ostringstream msg; + msg << "[gather] Slice sizes must be in [0, a.shape(i)]. Got " + << slice_sizes << " for array with shape " << a.shape() << "."; + throw std::invalid_argument(msg.str()); + } + } + + // Promote indices to the same type + auto dtype = result_type(indices); + if (!is_integral(dtype)) { + throw std::invalid_argument( + "[gather] Got indices with invalid dtype. Indices must be integral."); + } + + // Broadcast and cast indices if necessary + auto inputs = broadcast_arrays(indices); + for (auto& idx : inputs) { + idx = astype(idx, dtype, s); + } + + std::vector out_shape; + if (!inputs.empty()) { + out_shape = inputs[0].shape(); + } + out_shape.insert(out_shape.end(), slice_sizes.begin(), slice_sizes.end()); + + inputs.insert(inputs.begin(), a); + return array( + out_shape, + a.dtype(), + std::make_unique(to_stream(s), axes, slice_sizes), + inputs); +} + +array take( + const array& a, + const array& indices, + int axis, + StreamOrDevice s /* = {} */) { + // Check for valid axis + if (axis + static_cast(a.ndim()) < 0 || + axis >= static_cast(a.ndim())) { + std::ostringstream msg; + msg << "[take] Received invalid axis " << axis << " for array with " + << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + + // Check for valid take + if (a.size() == 0 && indices.size() != 0) { + throw std::invalid_argument( + "[take] Cannot do a non-empty take from an array with zero elements."); + } + + // Handle negative axis + axis = axis < 0 ? a.ndim() + axis : axis; + + // Make slice sizes to pass to gather + std::vector slice_sizes = a.shape(); + slice_sizes[axis] = indices.size() > 0 ? 
1 : 0; + + auto out = gather(a, indices, axis, slice_sizes, s); + + // Transpose indices dimensions to axis dimension + if (axis != 0) { + std::vector t_axes(out.ndim()); + std::iota(t_axes.begin(), t_axes.begin() + axis, indices.ndim()); + std::iota(t_axes.begin() + axis, t_axes.begin() + axis + indices.ndim(), 0); + std::iota( + t_axes.begin() + axis + indices.ndim(), + t_axes.end(), + indices.ndim() + axis); + out = transpose(out, t_axes, s); + } + + // Squeeze the axis we take over + std::vector out_shape = out.shape(); + out_shape.erase(out_shape.begin() + indices.ndim() + axis); + return reshape(out, out_shape, s); +} + +array take(const array& a, const array& indices, StreamOrDevice s /* = {} */) { + return take(reshape(a, {-1}, s), indices, 0, s); +} + +array take_along_axis( + const array& a, + const array& indices, + int axis, + StreamOrDevice s /* = {} */) { + if (axis + a.ndim() < 0 || axis >= static_cast(a.ndim())) { + std::ostringstream msg; + msg << "[take_along_axis] Received invalid axis " + << " for array with " << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + + if (indices.ndim() != a.ndim()) { + std::ostringstream msg; + msg << "[take_along_axis] Indices of dimension " << indices.ndim() + << " does not match array of dimension " << a.ndim() << "."; + throw std::invalid_argument(msg.str()); + } + + // Allow negative axis + axis = axis < 0 ? a.ndim() + axis : axis; + + std::vector nd_indices; + std::vector index_shape(a.ndim(), 1); + for (int i = 0; i < a.ndim(); ++i) { + if (i == axis) { + nd_indices.push_back(indices); + } else { + // Reshape so they can be broadcast + index_shape[i] = a.shape(i); + nd_indices.push_back(reshape(arange(a.shape(i), s), index_shape, s)); + index_shape[i] = 1; + } + } + std::vector dims(a.ndim()); + std::iota(dims.begin(), dims.end(), 0); + std::vector slice_sizes(a.ndim(), a.size() > 0); + auto out = gather(a, nd_indices, dims, slice_sizes, s); + + // Squeeze out the slice shape + std::vector out_shape( + out.shape().begin(), out.shape().begin() + a.ndim()); + return reshape(out, out_shape, s); +} + +/** Scatter updates to given indices */ +array scatter( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + Scatter::ReduceType mode /*= Scatter::ReduceType::None*/, + StreamOrDevice s /*= {}*/) { + // Checks that indices, dimensions, and slice_sizes are all valid + if (indices.size() > a.ndim()) { + std::ostringstream msg; + msg << "[scatter] Too many index arrays. 
Got " << indices.size() + << " index arrays for input with " << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + for (auto& x : indices) { + if (x.dtype() == bool_) { + throw("[scatter] Boolean indices not supported."); + } + } + + std::set dims(axes.begin(), axes.end()); + if (dims.size() != axes.size()) { + throw std::invalid_argument( + "[scatter] Repeat axes not allowed in scatter."); + } + if (!dims.empty() && (*dims.begin() < 0 || *dims.rbegin() >= a.ndim())) { + throw std::invalid_argument("[scatter] Axes don't match array dimensions."); + } + if (indices.size() != axes.size()) { + throw std::invalid_argument( + "[scatter] Number of index arrays does not match number of axes."); + } + + // Broadcast and cast indices if necessary + auto inputs = broadcast_arrays(indices); + + std::vector idx_shape; + if (!inputs.empty()) { + idx_shape = inputs[0].shape(); + } + + if (updates.ndim() != (a.ndim() + idx_shape.size())) { + std::ostringstream msg; + msg << "[scatter] Updates with " << updates.ndim() + << " dimensions does not match the sum of the array and indices " + "dimensions " + << a.ndim() + idx_shape.size() << "."; + throw std::invalid_argument(msg.str()); + } + for (int i = 0; i < idx_shape.size(); ++i) { + if (updates.shape(i) != idx_shape[i]) { + std::ostringstream msg; + msg << "[scatter] Update shape " << updates.shape() + << " is not valid for broadcasted index shape " << idx_shape << "."; + throw std::invalid_argument(msg.str()); + } + } + for (int i = 0; i < a.ndim(); ++i) { + auto up_shape = updates.shape(i + idx_shape.size()); + if (up_shape > a.shape(i)) { + std::ostringstream msg; + msg << "[scatter] Updates with shape " << updates.shape() + << " are too large for array with shape " << a.shape() << "."; + throw std::invalid_argument(msg.str()); + } + } + + // Promote indices to the same type + auto dtype = result_type(indices); + if (!is_integral(dtype)) { + throw std::invalid_argument( + "[scatter] Got indices with invalid dtype. Indices must be integral."); + } + for (auto& idx : inputs) { + idx = astype(idx, dtype, s); + } + + inputs.insert(inputs.begin(), a); + // TODO promote or cast? 
+ inputs.push_back(astype(updates, a.dtype(), s)); + return array( + a.shape(), + a.dtype(), + std::make_unique(to_stream(s), mode, axes), + inputs); +} + +array scatter( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + StreamOrDevice s /*= {}*/) { + return scatter(a, indices, updates, axes, Scatter::None, s); +} + +array scatter_add( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + StreamOrDevice s /*= {}*/) { + return scatter(a, indices, updates, axes, Scatter::Sum, s); +} + +array scatter_prod( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + StreamOrDevice s /*= {}*/) { + return scatter(a, indices, updates, axes, Scatter::Prod, s); +} + +array scatter_max( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + StreamOrDevice s /*= {}*/) { + return scatter(a, indices, updates, axes, Scatter::Max, s); +} + +array scatter_min( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + StreamOrDevice s /*= {}*/) { + return scatter(a, indices, updates, axes, Scatter::Min, s); +} + +array sqrt(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + return array( + a.shape(), + dtype, + std::make_unique(to_stream(s)), + {astype(a, dtype, s)}); +} + +array rsqrt(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + return array( + a.shape(), + dtype, + std::make_unique(to_stream(s), true), + {astype(a, dtype, s)}); +} + +array softmax( + const array& a, + const std::vector& axes, + StreamOrDevice s /* = {}*/) { + if (axes.size() == 1 && (a.ndim() == axes[0] + 1 || axes[0] == -1)) { + auto dtype = at_least_float(a.dtype()); + return array( + a.shape(), + dtype, + std::make_unique(to_stream(s)), + {astype(a, dtype, s)}); + } else { + auto a_max = stop_gradient(max(a, axes, /*keepdims = */ true, s), s); + auto ex = exp(subtract(a, a_max, s), s); + return divide(ex, sum(ex, axes, /*keepdims = */ true, s), s); + } +} + +array softmax(const array& a, StreamOrDevice s /* = {}*/) { + std::vector axes(a.ndim()); + std::iota(axes.begin(), axes.end(), 0); + return softmax(a, axes, s); +} + +array power(const array& a, const array& b, StreamOrDevice s /* = {} */) { + auto dtype = promote_types(a.dtype(), b.dtype()); + std::vector inputs = {astype(a, dtype, s), astype(b, dtype, s)}; + if (a.shape() != b.shape()) { + inputs = broadcast_arrays(inputs, s); + } + return array( + inputs[0].shape(), dtype, std::make_unique(to_stream(s)), inputs); +} + +array cumsum( + const array& a, + int axis, + bool reverse /* = false*/, + bool inclusive /* = true*/, + StreamOrDevice s /* = {}*/) { + int ndim = a.ndim(); + if (axis >= ndim || axis < -ndim) { + std::ostringstream msg; + msg << "[cumsum] Axis " << axis << " is out of bounds for array with " + << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + axis = (axis + a.ndim()) % a.ndim(); + auto out_type = a.dtype() == bool_ ? 
int32 : a.dtype(); + return array( + a.shape(), + out_type, + std::make_unique( + to_stream(s), Scan::ReduceType::Sum, axis, reverse, inclusive), + {a}); +} + +array cumprod( + const array& a, + int axis, + bool reverse /* = false*/, + bool inclusive /* = true*/, + StreamOrDevice s /* = {}*/) { + int ndim = a.ndim(); + if (axis >= ndim || axis < -ndim) { + std::ostringstream msg; + msg << "[cumprod] Axis " << axis << " is out of bounds for array with " + << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + axis = (axis + a.ndim()) % a.ndim(); + return array( + a.shape(), + a.dtype(), + std::make_unique( + to_stream(s), Scan::ReduceType::Prod, axis, reverse, inclusive), + {a}); +} + +array cummax( + const array& a, + int axis, + bool reverse /* = false*/, + bool inclusive /* = true*/, + StreamOrDevice s /* = {}*/) { + int ndim = a.ndim(); + if (axis >= ndim || axis < -ndim) { + std::ostringstream msg; + msg << "[cummax] Axis " << axis << " is out of bounds for array with " + << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + axis = (axis + a.ndim()) % a.ndim(); + return array( + a.shape(), + a.dtype(), + std::make_unique( + to_stream(s), Scan::ReduceType::Max, axis, reverse, inclusive), + {a}); +} + +array cummin( + const array& a, + int axis, + bool reverse /* = false*/, + bool inclusive /* = true*/, + StreamOrDevice s /* = {}*/) { + int ndim = a.ndim(); + if (axis >= ndim || axis < -ndim) { + std::ostringstream msg; + msg << "[cummin] Axis " << axis << " is out of bounds for array with " + << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + axis = (axis + a.ndim()) % a.ndim(); + return array( + a.shape(), + a.dtype(), + std::make_unique( + to_stream(s), Scan::ReduceType::Min, axis, reverse, inclusive), + {a}); +} + +/** Convolution operations */ + +namespace { + +// Conv helpers +inline int conv_out_axis_size( + int in_dim, + int wt_dim, + int stride, + int padding, + int dilation) { + int ker = dilation * (wt_dim - 1); + return ((in_dim + 2 * padding - ker - 1) / stride) + 1; +} + +inline std::vector conv_out_shape( + const std::vector& in_shape, + const std::vector& wt_shape, + const std::vector& strides, + const std::vector& pads, + const std::vector& dilation) { + int N = in_shape[0]; + int O = wt_shape[0]; + std::vector out_shape(in_shape.size()); + int i = 0; + out_shape[i++] = N; + for (; i < in_shape.size() - 1; i++) { + out_shape[i] = conv_out_axis_size( + in_shape[i], wt_shape[i], strides[i - 1], pads[i - 1], dilation[i - 1]); + } + out_shape[i] = O; + + return out_shape; +} + +inline void run_conv_checks(const array& in, const array& wt, int n_dim) { + if (!is_floating_point(in.dtype()) && kindof(in.dtype()) != Dtype::Kind::c) { + std::ostringstream msg; + msg << "[conv] Invalid input array with type " << in.dtype() << "." + << " Convolution currently only supports floating point types"; + throw std::invalid_argument(msg.str()); + } + + if (in.ndim() != n_dim + 2) { + std::ostringstream msg; + msg << "[conv] Invalid input array with " << in.ndim() << " dimensions for " + << n_dim << "D convolution." + << " Expected an array with " << n_dim + 2 + << " dimensions following the format [N, ..., C_in]."; + throw std::invalid_argument(msg.str()); + } + + if (wt.ndim() != n_dim + 2) { + std::ostringstream msg; + msg << "[conv] Invalid weight array with " << wt.ndim() + << " dimensions for " << n_dim << "D convolution." 
+ << " Expected an array with " << n_dim + 2 + << " dimensions following the format [C_out, ..., C_in]."; + throw std::invalid_argument(msg.str()); + } + + if (in.shape(n_dim + 1) != wt.shape(n_dim + 1)) { + std::ostringstream msg; + msg << "[conv] Expect the input channels in the input" + << " and weight array to match but got shapes -" + << " input: " << in.shape() << " and weight: " << wt.shape(); + throw std::invalid_argument(msg.str()); + } +} + +} // namespace + +/** 1D convolution with a filter */ +array conv1d( + const array& in_, + const array& wt_, + int stride /* = 1 */, + int padding /* = 0 */, + int dilation /* = 1 */, + int groups /* = 1 */, + StreamOrDevice s /* = {} */) { + // Run checks + if (groups != 1) { + throw std::invalid_argument("[conv1d] Cannot handle groups != 1 yet"); + } + if (dilation != 1) { + throw std::invalid_argument("[conv1d] Cannot handle dilation != 1 yet"); + } + + // Run checks + run_conv_checks(in_, wt_, 1); + + auto in = in_; + auto wt = wt_; + + // Type promotion + auto out_type = promote_types(in.dtype(), wt.dtype()); + in = astype(in, out_type, s); + wt = astype(wt, out_type, s); + + std::vector strides_vec = {stride}; + std::vector padding_vec = {padding}; + std::vector dilation_vec = {dilation}; + + // Get output shapes + std::vector out_shape = conv_out_shape( + in.shape(), wt.shape(), strides_vec, padding_vec, dilation_vec); + + return array( + out_shape, + in.dtype(), + std::make_unique( + to_stream(s), + padding_vec, + strides_vec, + dilation_vec, + std::vector(1, 1)), + {in, wt}); +} + +/** 2D convolution with a filter */ +array conv2d( + const array& in_, + const array& wt_, + const std::pair& stride /* = {1, 1} */, + const std::pair& padding /* = {0, 0} */, + const std::pair& dilation /* = {1, 1} */, + int groups /* = 1 */, + StreamOrDevice s /* = {} */) { + // Run checks + if (groups != 1) { + throw std::invalid_argument("[conv2d] Cannot handle groups != 1 yet"); + } + if (dilation.first != 1 || dilation.second != 1) { + throw std::invalid_argument("[conv1d] Cannot handle dilation != 1 yet"); + } + + // Run checks + run_conv_checks(in_, wt_, 2); + + auto in = in_; + auto wt = wt_; + + // Type promotion + auto out_type = promote_types(in.dtype(), wt.dtype()); + in = astype(in, out_type, s); + wt = astype(wt, out_type, s); + + std::vector strides_vec = {stride.first, stride.second}; + std::vector padding_vec = {padding.first, padding.second}; + std::vector dilation_vec = {dilation.first, dilation.second}; + + // Get output shapes + std::vector out_shape = conv_out_shape( + in.shape(), wt.shape(), strides_vec, padding_vec, dilation_vec); + + return array( + out_shape, + in.dtype(), + std::make_unique( + to_stream(s), + padding_vec, + strides_vec, + dilation_vec, + std::vector(2, 1)), + {in, wt}); +} + +} // namespace mlx::core diff --git a/mlx/transforms_impl.h b/mlx/transforms_impl.h new file mode 100644 index 000000000..e97d41d73 --- /dev/null +++ b/mlx/transforms_impl.h @@ -0,0 +1,16 @@ + +namespace mlx::core::detail { + +std::pair, std::vector> vmap_trace( + const std::function(const std::vector&)>& fun, + const std::vector& inputs, + const std::vector& in_axes); + +std::vector vmap_replace( + const std::vector& inputs, + const std::vector& s_inputs, + const std::vector& s_outputs, + const std::vector& in_axes, + const std::vector& out_axes); + +} // namespace mlx::core::detail diff --git a/mlx/types/complex.h b/mlx/types/complex.h new file mode 100644 index 000000000..b7533efba --- /dev/null +++ b/mlx/types/complex.h @@ -0,0 +1,75 @@ 
+#pragma once +#include +#include "mlx/types/half_types.h" + +namespace mlx::core { + +struct complex64_t; + +template +static constexpr bool can_convert_to_complex64 = + !std::is_same_v && std::is_convertible_v; + +struct complex64_t : public std::complex { + complex64_t(float v, float u) : std::complex(v, u){}; + complex64_t(std::complex v) : std::complex(v){}; + + template < + typename T, + typename = typename std::enable_if>::type> + complex64_t(T x) : std::complex(x){}; + + operator float() const { + return real(); + }; +}; + +inline bool operator>=(const complex64_t& a, const complex64_t& b) { + return (a.real() > b.real()) || + (a.real() == b.real() && a.imag() >= b.imag()); +} + +inline bool operator>(const complex64_t& a, const complex64_t& b) { + return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag()); +} + +inline bool operator<=(const complex64_t& a, const complex64_t& b) { + return operator>=(b, a); +} + +inline bool operator<(const complex64_t& a, const complex64_t& b) { + return operator>(b, a); +} + +inline complex64_t operator-(const complex64_t& v) { + return -static_cast>(v); +} + +// clang-format off +#define complex_binop_helper(_op_, _operator_, itype) \ + inline complex64_t _operator_(itype x, const complex64_t& y) { \ + return x _op_ static_cast>(y); \ + } \ + inline complex64_t _operator_(const complex64_t& x, itype y) { \ + return static_cast>(x) _op_ y; \ + } + +#define complex_binop(_op_, _operator_) \ + inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \ + return static_cast>(x) \ + _op_ static_cast>(y); \ + } \ + complex_binop_helper(_op_, _operator_, bool) \ + complex_binop_helper(_op_, _operator_, uint32_t) \ + complex_binop_helper(_op_, _operator_, uint64_t) \ + complex_binop_helper(_op_, _operator_, int32_t) \ + complex_binop_helper(_op_, _operator_, int64_t) \ + complex_binop_helper(_op_, _operator_, float16_t) \ + complex_binop_helper(_op_, _operator_, bfloat16_t) \ + complex_binop_helper(_op_, _operator_, const std::complex&) \ + complex_binop_helper(_op_, _operator_, float) +// clang-format on + +complex_binop(+, operator+) + +} // namespace mlx::core diff --git a/mlx/types/fp16.h b/mlx/types/fp16.h new file mode 100644 index 000000000..e4f94c994 --- /dev/null +++ b/mlx/types/fp16.h @@ -0,0 +1,232 @@ +#pragma once + +#include +#include +#include +#include + +#define __MLX_HALF_NAN__ 0x7D00 + +namespace mlx::core { + +namespace { +union float_bits_fp16 { + float f; + uint32_t u; +}; +} // namespace + +struct _MLX_Float16 { + uint16_t bits_; + + // Default constructor + _MLX_Float16() = default; + + // Default copy constructor + _MLX_Float16(_MLX_Float16 const&) = default; + + // Appease std::vector for being special + _MLX_Float16& operator=(std::vector::reference x) { + bits_ = x; + return *this; + } + + _MLX_Float16& operator=(const float& x) { + return (*this = _MLX_Float16(x)); + } + + // From float32 + _MLX_Float16(const float& x) : bits_(0) { + // Conversion following + // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h + + // Union + float_bits_fp16 in; + + // Take fp32 bits + in.f = x; + + // Find and take sign bit + uint32_t x_sign_32 = in.u & uint32_t(0x80000000); + uint16_t x_sign_16 = (x_sign_32 >> 16); + + if (std::isnan(x)) { + bits_ = x_sign_16 | uint16_t(__MLX_HALF_NAN__); + } else { + // Union + float_bits_fp16 inf_scale, zero_scale, magic_bits; + + // Find exponent bits and take the max supported by half + uint32_t x_expo_32 = in.u & uint32_t(0x7f800000); + uint32_t 
max_expo_32 = uint32_t(0x38800000); + x_expo_32 = x_expo_32 < max_expo_32 ? max_expo_32 : x_expo_32; + x_expo_32 += uint32_t(15) << 23; + + // Handle scaling to inf as needed + inf_scale.u = uint32_t(0x77800000); + zero_scale.u = uint32_t(0x08800000); + + // Combine with magic and let addition do rouding + magic_bits.u = x_expo_32; + magic_bits.f += (std::abs(x) * inf_scale.f) * zero_scale.f; + + // Take the lower 5 bits of the exponent + uint32_t x_expo_16 = ((magic_bits.u >> 13) & uint32_t(0x7c00)); + + // Collect the lower 12 bits which have the mantissa + uint32_t x_mant_16 = magic_bits.u & uint32_t(0x0fff); + + // Combine sign, exp and mantissa + bits_ = (x_sign_16 | uint16_t(x_expo_16 + x_mant_16)); + } + } + + // To float32 + operator float() const { + // Conversion following + // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h + + // Union + float_bits_fp16 out; + + uint32_t x_sign_32 = (bits_ << 16) & uint32_t(0x80000000); + uint32_t base = (bits_ << 16); + uint32_t two_base = base + base; + + uint32_t denorm_max = 1u << 27; + if (two_base < denorm_max) { + out.u = uint32_t(126) << 23; // magic mask + out.u |= (two_base >> 17); // Bits from fp16 + out.f -= 0.5f; // magic bias + } else { + out.u = uint32_t(0xE0) << 23; // exponent offset + out.u += (two_base >> 4); // Bits from fp16 + float out_unscaled = out.f; // Store value + out.u = uint32_t(0x7800000); // exponent scale + out.f *= out_unscaled; + } + + // Add sign + out.u |= x_sign_32; + + return out.f; + } +}; + +#define half_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + inline otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +#define half_binop_helper(__op__, __operator__, otype, itype, ctype) \ + inline otype __operator__(_MLX_Float16 lhs, itype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + inline otype __operator__(itype lhs, _MLX_Float16 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +// Operators +#define half_binop(__op__, __operator__) \ + half_binop_base( \ + __op__, __operator__, _MLX_Float16, _MLX_Float16, _MLX_Float16, float); \ + half_binop_helper(__op__, __operator__, float, float, float); \ + half_binop_helper(__op__, __operator__, double, double, double); \ + half_binop_helper(__op__, __operator__, _MLX_Float16, bool, float); \ + half_binop_helper(__op__, __operator__, _MLX_Float16, int32_t, float); \ + half_binop_helper(__op__, __operator__, _MLX_Float16, uint32_t, float); \ + half_binop_helper(__op__, __operator__, _MLX_Float16, int64_t, float); \ + half_binop_helper(__op__, __operator__, _MLX_Float16, uint64_t, float); + +half_binop(+, operator+); +half_binop(-, operator-); +half_binop(*, operator*); +half_binop(/, operator/); + +#undef half_binop + +// Comparison ops +#define half_compop(__op__, __operator__) \ + half_binop_base( \ + __op__, __operator__, bool, _MLX_Float16, _MLX_Float16, float); \ + half_binop_helper(__op__, __operator__, bool, float, float); \ + half_binop_helper(__op__, __operator__, bool, double, double); \ + half_binop_helper(__op__, __operator__, bool, int32_t, float); \ + half_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + half_binop_helper(__op__, __operator__, bool, int64_t, float); \ + half_binop_helper(__op__, __operator__, bool, uint64_t, float); + +half_compop(>, operator>); +half_compop(<, operator<); +half_compop(>=, operator>=); +half_compop(<=, operator<=); +half_compop(==, operator==); +half_compop(!=, operator!=); + 
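The `_MLX_Float16` conversions above pack values into the IEEE binary16 layout (1 sign bit, 5 exponent bits, 10 mantissa bits). The resulting bit patterns are easy to sanity-check from Python with the standard library's half-precision struct format (illustration only, not part of the bindings):

```
import struct

def fp16_bits(x: float) -> int:
    # Round-trip through IEEE binary16 ('e' format) and read back the raw 16 bits
    return struct.unpack("<H", struct.pack("<e", x))[0]

assert fp16_bits(1.0) == 0x3C00   # 0 01111 0000000000
assert fp16_bits(-2.0) == 0xC000  # 1 10000 0000000000
# NaN is any pattern with all exponent bits set and a non-zero mantissa;
# 0x7D00 (__MLX_HALF_NAN__ above) is one such payload.
```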
+#undef half_compop + +// Negative +inline _MLX_Float16 operator-(_MLX_Float16 lhs) { + return -static_cast(lhs); +} + +// Inplace ops +#define half_inplace_op(__op__, __operator__) \ + inline _MLX_Float16& __operator__(_MLX_Float16& lhs, const float& rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } \ + inline float& __operator__(float& lhs, _MLX_Float16 rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } + +half_inplace_op(+, operator+=); +half_inplace_op(-, operator-=); +half_inplace_op(*, operator*=); +half_inplace_op(/, operator/=); + +#undef half_inplace_op + +// Bitwise ops + +#define half_bitop(__op__, __operator__) \ + inline _MLX_Float16 __operator__(_MLX_Float16 lhs, _MLX_Float16 rhs) { \ + _MLX_Float16 out; \ + out.bits_ = lhs.bits_ __op__ rhs.bits_; \ + return out; \ + } \ + inline _MLX_Float16 __operator__(_MLX_Float16 lhs, uint16_t rhs) { \ + _MLX_Float16 out; \ + out.bits_ = lhs.bits_ __op__ rhs; \ + return out; \ + } \ + inline _MLX_Float16 __operator__(uint16_t lhs, _MLX_Float16 rhs) { \ + _MLX_Float16 out; \ + out.bits_ = lhs __op__ rhs.bits_; \ + return out; \ + } + +half_bitop(|, operator|); +half_bitop(&, operator&); +half_bitop(^, operator^); + +#undef half_bitop + +#define half_inplace_bitop(__op__, __operator__) \ + inline _MLX_Float16& __operator__(_MLX_Float16& lhs, _MLX_Float16 rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \ + return lhs; \ + } \ + inline _MLX_Float16& __operator__(_MLX_Float16& lhs, uint16_t rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs; \ + return lhs; \ + } + +half_inplace_bitop(|, operator|=); +half_inplace_bitop(&, operator&=); +half_inplace_bitop(^, operator^=); + +#undef half_inplace_bitop + +} // namespace mlx::core diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..09ead06ee --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=42", "pybind11>=2.10", "cmake>=3.24"] +build-backend = "setuptools.build_meta" diff --git a/python/mlx/nn/layers/convolution.py b/python/mlx/nn/layers/convolution.py new file mode 100644 index 000000000..1e7ac1c3d --- /dev/null +++ b/python/mlx/nn/layers/convolution.py @@ -0,0 +1,124 @@ +import math +from typing import Union + +import mlx.core as mx +from mlx.nn.layers.base import Module + + +class Conv1d(Module): + """Applies a 1-dimensional convolution over the multi-channel input sequence. + + The channels are expected to be last i.e. the input shape should be ``NLC`` where: + - ``N`` is the batch dimension + - ``L`` is the sequence length + - ``C`` is the number of input channels + + Args: + in_channels (int): The number of input channels + out_channels (int): The number of output channels + kernel_size (int): The size of the convolution filters + stride (int, optional): The stride when applying the filter. + Default: 1. + padding (int, optional): How many positions to 0-pad the input with. + Default: 0. + bias (bool, optional): If ``True`` add a learnable bias to the output. 
+ Default: ``True`` + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + bias: bool = True, + ): + super().__init__() + + scale = math.sqrt(1 / (in_channels * kernel_size)) + self.weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(out_channels, kernel_size, in_channels), + ) + if bias: + self.bias = mx.zeros((out_channels,)) + + self.padding = padding + self.stride = stride + + def _extra_repr(self): + return ( + f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " + f"kernel_size={self.weight.shape[1]}, stride={self.stride}, " + f"padding={self.padding}, bias={'bias' in self}" + ) + + def __call__(self, x): + y = mx.conv1d(x, self.weight, self.stride, self.padding) + if "bias" in self: + y = y + self.bias + return y + + +class Conv2d(Module): + """Applies a 2-dimensional convolution over the multi-channel input image. + + The channels are expected to be last i.e. the input shape should be ``NHWC`` where: + - ``N`` is the batch dimension + - ``H`` is the input image height + - ``W`` is the input image width + - ``C`` is the number of input channels + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + kernel_size (int or tuple): The size of the convolution filters. + stride (int or tuple, optional): The size of the stride when + applying the filter. Default: 0. + padding (int or tuple, optional): How many positions to 0-pad + the input with. Default: 0. + bias (bool, optional): If ``True`` add a learnable bias to the + output. Default: ``True`` + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, tuple], + stride: Union[int, tuple] = 1, + padding: Union[int, tuple] = 0, + bias: bool = True, + ): + super().__init__() + + kernel_size, stride, padding = map( + lambda x: (x, x) if isinstance(x, int) else x, + (kernel_size, stride, padding), + ) + scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1])) + self.weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(out_channels, *kernel_size, in_channels), + ) + if bias: + self.bias = mx.zeros((out_channels,)) + + self.padding = padding + self.stride = stride + + def _extra_repr(self): + return ( + f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " + f"kernel_size={self.weight.shape[1:2]}, stride={self.stride}, " + f"padding={self.padding}, bias={'bias' in self}" + ) + + def __call__(self, x): + y = mx.conv2d(x, self.weight, self.stride, self.padding) + if "bias" in self: + y = y + self.bias + return y diff --git a/python/mlx/nn/layers/embedding.py b/python/mlx/nn/layers/embedding.py new file mode 100644 index 000000000..d6e9b976a --- /dev/null +++ b/python/mlx/nn/layers/embedding.py @@ -0,0 +1,28 @@ +import math + +import mlx.core as mx +from mlx.nn.layers.base import Module + + +class Embedding(Module): + """Implements a simple lookup table that maps each input integer to a + high-dimensional vector. + + Typically used to embed discrete tokens for processing by neural networks. + + Args: + num_embeddings (int): How many possible discrete tokens can we embed. + Usually called the vocabulary size. + dims (int): The dimensionality of the embeddings. 
+ """ + + def __init__(self, num_embeddings: int, dims: int): + super().__init__() + scale = math.sqrt(1 / dims) + self.weight = mx.random.normal((num_embeddings, dims)) * scale + + def _extra_repr(self): + return f"{self.weight.shape[0]}, {self.weight.shape[1]}" + + def __call__(self, x): + return self.weight[x] diff --git a/python/mlx/nn/layers/linear.py b/python/mlx/nn/layers/linear.py new file mode 100644 index 000000000..aa48688d6 --- /dev/null +++ b/python/mlx/nn/layers/linear.py @@ -0,0 +1,34 @@ +import math + +import mlx.core as mx +from mlx.nn.layers.base import Module + + +class Linear(Module): + """Applies an affine transformation to the input. + + Args: + input_dims (int): The dimensionality of the input features + output_dims (int): The dimensionality of the output features + bias (bool): If set to False then the layer will not use a bias + """ + + def __init__(self, input_dims: int, output_dims: int, bias: bool = True): + super().__init__() + scale = math.sqrt(1 / input_dims) + self.weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(output_dims, input_dims), + ) + if bias: + self.bias = mx.zeros((output_dims,)) + + def _extra_repr(self): + return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}" + + def __call__(self, x): + x = x @ self.weight.T + if "bias" in self: + x = x + self.bias + return x diff --git a/python/src/load.h b/python/src/load.h new file mode 100644 index 000000000..f7f1a4148 --- /dev/null +++ b/python/src/load.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include +#include +#include "mlx/ops.h" + +namespace py = pybind11; +using namespace mlx::core; + +using DictOrArray = std::variant>; + +DictOrArray mlx_load_helper(py::object file, StreamOrDevice s); +void mlx_save_helper(py::object file, array a, bool retain_graph = true); +void mlx_savez_helper( + py::object file, + py::args args, + const py::kwargs& kwargs, + bool compressed = false); \ No newline at end of file diff --git a/python/src/mlx.cpp b/python/src/mlx.cpp new file mode 100644 index 000000000..0e9aeba65 --- /dev/null +++ b/python/src/mlx.cpp @@ -0,0 +1,31 @@ +#include + +#define STRINGIFY(x) #x +#define TOSTRING(x) STRINGIFY(x) + +namespace py = pybind11; + +void init_array(py::module_&); +void init_device(py::module_&); +void init_stream(py::module_&); +void init_metal(py::module_&); +void init_ops(py::module_&); +void init_transforms(py::module_&); +void init_random(py::module_&); +void init_fft(py::module_&); + +PYBIND11_MODULE(core, m) { + m.doc() = "mlx: A framework for machine learning on Apple Silicon."; + + auto reprlib_fix = py::module_::import("mlx._reprlib_fix"); + + init_device(m); + init_stream(m); + init_array(m); + init_metal(m); + init_ops(m); + init_transforms(m); + init_random(m); + init_fft(m); + m.attr("__version__") = TOSTRING(_VERSION_); +} diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp new file mode 100644 index 000000000..b257ab40f --- /dev/null +++ b/python/src/transforms.cpp @@ -0,0 +1,723 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "mlx/array.h" +#include "mlx/graph_utils.h" +#include "mlx/transforms.h" +#include "mlx/transforms_impl.h" + +namespace py = pybind11; +using namespace py::literals; +using namespace mlx::core; + +using IntOrVec = std::variant>; +using StrOrVec = std::variant>; + +template +std::vector to_vector(const std::variant>& v) { + std::vector vals; + if (auto pv = std::get_if(&v); pv) { + vals.push_back(*pv); + } else { + vals = 
std::get>(v); + } + return vals; +} + +void tree_visit(py::object tree, std::function visitor) { + std::function recurse; + recurse = [&](py::handle subtree) { + if (py::isinstance(subtree) || + py::isinstance(subtree)) { + for (auto item : subtree) { + recurse(item); + } + } else if (py::isinstance(subtree)) { + for (auto item : py::cast(subtree)) { + recurse(item.second); + } + } else { + visitor(subtree); + } + }; + + recurse(tree); +} + +template +void validate_subtrees(const std::vector& subtrees) { + int len = py::cast(subtrees[0]).size(); + for (auto& subtree : subtrees) { + if ((py::isinstance(subtree) && py::cast(subtree).size() != len) || + py::isinstance(subtree) || py::isinstance(subtree)) { + throw std::invalid_argument( + "[tree_map] Additional input tree is not a valid prefix of the first tree."); + } + } +} + +py::object tree_map( + const std::vector& trees, + std::function&)> transform) { + std::function&)> recurse; + + recurse = [&](const std::vector& subtrees) { + if (py::isinstance(subtrees[0])) { + py::list l; + std::vector items(subtrees.size()); + validate_subtrees(subtrees); + for (int i = 0; i < py::cast(subtrees[0]).size(); ++i) { + for (int j = 0; j < subtrees.size(); ++j) { + if (py::isinstance(subtrees[j])) { + items[j] = py::cast(subtrees[j])[i]; + } else { + items[j] = subtrees[j]; + } + } + l.append(recurse(items)); + } + return py::cast(l); + } else if (py::isinstance(subtrees[0])) { + // Check the rest of the subtrees + std::vector items(subtrees.size()); + int len = py::cast(subtrees[0]).size(); + py::tuple l(len); + validate_subtrees(subtrees); + for (int i = 0; i < len; ++i) { + for (int j = 0; j < subtrees.size(); ++j) { + if (py::isinstance(subtrees[j])) { + items[j] = py::cast(subtrees[j])[i]; + } else { + items[j] = subtrees[j]; + } + } + l[i] = recurse(items); + } + return py::cast(l); + } else if (py::isinstance(subtrees[0])) { + std::vector items(subtrees.size()); + validate_subtrees(subtrees); + py::dict d; + for (auto item : py::cast(subtrees[0])) { + for (int j = 0; j < subtrees.size(); ++j) { + if (py::isinstance(subtrees[j])) { + auto subdict = py::cast(subtrees[j]); + if (!subdict.contains(item.first)) { + throw std::invalid_argument( + "[tree_map] Tree is not a valid prefix tree of the first tree."); + } + items[j] = subdict[item.first]; + } else { + items[j] = subtrees[j]; + } + } + d[item.first] = recurse(items); + } + return py::cast(d); + } else { + return transform(subtrees); + } + }; + return recurse(trees); +} + +py::object tree_map( + py::object tree, + std::function transform) { + return tree_map({tree}, [&](std::vector inputs) { + return transform(inputs[0]); + }); +} + +std::vector tree_flatten(py::object tree, bool strict = true) { + std::vector flat_tree; + + tree_visit(tree, [&](py::handle obj) { + if (py::isinstance(obj)) { + flat_tree.push_back(py::cast(obj)); + } else if (strict) { + throw std::invalid_argument("Argument is not an array"); + } + }); + + return flat_tree; +} + +py::object tree_unflatten( + py::object tree, + const std::vector& values, + int index = 0) { + return tree_map(tree, [&](py::handle obj) { + if (py::isinstance(obj)) { + return py::cast(values[index++]); + } else { + return py::cast(obj); + } + }); +} + +auto validate_argnums_argnames( + const std::optional& argnums, + const StrOrVec& argnames) { + auto vec_names = to_vector(argnames); + + if (!argnums.has_value()) { + // argnums was not provided and argnames was empty + if (vec_names.empty()) { + return std::make_pair(std::vector{0}, vec_names); + 
} else { + return std::make_pair(std::vector{}, vec_names); + } + } + + return std::make_pair(to_vector(*argnums), vec_names); +} + +auto py_value_and_grad( + const py::function& fun, + std::vector argnums, + std::vector argnames, + const std::string& error_msg_tag, + bool scalar_func_only) { + // Sanitize argnums + if (argnums.size() == 0 && argnames.size() == 0) { + throw std::invalid_argument( + error_msg_tag + " Gradient wrt no argument requested"); + } + if (argnums.size() > 0) { + std::sort(argnums.begin(), argnums.end()); + if (argnums[0] < 0) { + std::ostringstream msg; + msg << error_msg_tag + << " Can't compute the gradient of negative argument index " + << argnums[0]; + throw std::invalid_argument(msg.str()); + } + } + + return [fun, argnums, argnames, error_msg_tag, scalar_func_only]( + const py::args& args, const py::kwargs& kwargs) { + // Sanitize the input + if (argnums.size() > 0 && argnums.back() >= args.size()) { + std::ostringstream msg; + msg << error_msg_tag << " Can't compute the gradient of argument index " + << argnums.back() << " because the function is called with only " + << args.size() << " arguments."; + throw std::invalid_argument(msg.str()); + } + + for (auto& key : argnames) { + if (!kwargs.contains(key)) { + std::ostringstream msg; + msg << error_msg_tag + << " Can't compute the gradient of keyword argument '" << key + << "' because the function is called with the " + << "following keyword arguments {"; + for (auto item : kwargs) { + msg << item.first.cast() << ","; + } + msg << "}"; + throw std::invalid_argument(msg.str()); + } + } + + // Collect the arrays + std::vector arrays; + std::vector counts(1, 0); + for (auto i : argnums) { + auto argsi = tree_flatten(args[i]); + arrays.insert(arrays.end(), argsi.begin(), argsi.end()); + counts.push_back(argsi.size()); + } + for (auto& key : argnames) { + auto argsk = tree_flatten(kwargs[key.c_str()]); + arrays.insert(arrays.end(), argsk.begin(), argsk.end()); + counts.push_back(argsk.size()); + } + std::partial_sum(counts.cbegin(), counts.cend(), counts.begin()); + std::vector gradient_indices(arrays.size()); + std::iota(gradient_indices.begin(), gradient_indices.end(), 0); + + // value_out will hold the output of the python function in order to be + // able to reconstruct the python tree of extra return values + py::object py_value_out; + auto value_and_grads = value_and_grad( + [&fun, + &args, + &kwargs, + &argnums, + &argnames, + &counts, + &py_value_out, + &error_msg_tag, + scalar_func_only](const std::vector& a) { + // Copy the arguments + py::args args_cpy = py::tuple(args.size()); + py::kwargs kwargs_cpy = py::kwargs(); + int j = 0; + for (int i = 0; i < args.size(); ++i) { + if (j < argnums.size() && i == argnums[j]) { + args_cpy[i] = tree_unflatten(args[i], a, counts[j]); + j++; + } else { + args_cpy[i] = args[i]; + } + } + for (auto& key : argnames) { + kwargs_cpy[key.c_str()] = + tree_unflatten(kwargs[key.c_str()], a, counts[j]); + j++; + } + for (auto item : kwargs) { + if (kwargs_cpy.contains(item.first)) { + continue; + } + kwargs_cpy[item.first] = item.second; + } + + // Call the python function + py_value_out = fun(*args_cpy, **kwargs_cpy); + + // Validate the return value of the python function + if (!py::isinstance(py_value_out)) { + if (scalar_func_only) { + std::ostringstream msg; + msg << error_msg_tag << " The return value of the function " + << "whose gradient we want to compute should be a " + << "scalar array; but " << py_value_out.get_type() + << " was returned."; + throw 
std::invalid_argument(msg.str()); + } + if (!py::isinstance(py_value_out)) { + std::ostringstream msg; + msg << error_msg_tag << " The return value of the function " + << "whose gradient we want to compute should be either a " + << "scalar array or a tuple with the first value being a " + << "scalar array (Union[array, Tuple[array, Any, ...]]); but " + << py_value_out.get_type() << " was returned."; + throw std::invalid_argument(msg.str()); + } + py::tuple ret = py::cast(py_value_out); + if (ret.size() == 0) { + std::ostringstream msg; + msg << error_msg_tag << " The return value of the function " + << "whose gradient we want to compute should be either a " + << "scalar array or a non-empty tuple. The first value should be a " + << "scalar array and the rest can be anything. Instead, " + << "we got an empty tuple."; + throw std::invalid_argument(msg.str()); + } + if (!py::isinstance(ret[0])) { + std::ostringstream msg; + msg << error_msg_tag << " The return value of the function " + << "whose gradient we want to compute should be either a " + << "scalar array or a tuple with the first value being a " + << "scalar array (Union[array, Tuple[array, Any, ...]]); but it " + << "was a tuple with the first value being of type " + << ret[0].get_type() << " ."; + throw std::invalid_argument(msg.str()); + } + } + + return tree_flatten(py_value_out, false); + }, + gradient_indices)(arrays); + + auto value = value_and_grads.first; + auto gradients = value_and_grads.second; + + // Put the gradients back in their container. + // We have the following cases: + // + // 1. Single python positional argument has a gradient (eg argnums=[0]) + // 2. Many python positional arguments have gradients (eg argnums=[0, 1]) + // 3. A python keyword argument has gradients + // + // In case 1 we return the original python variable but with the gradients. + // In case 2 we return a tuple of the above. + // In case 3 we return a tuple containing a tuple and dict (sth like + // (tuple(), dict(x=mx.array(5))) ). 
+ py::object positional_grads; + py::object keyword_grads; + py::object py_grads; + + // Collect the gradients for the positional arguments + if (argnums.size() == 1) { + positional_grads = tree_unflatten(args[argnums[0]], gradients, counts[0]); + } else if (argnums.size() > 1) { + py::tuple grads_(argnums.size()); + for (int i = 0; i < argnums.size(); i++) { + grads_[i] = tree_unflatten(args[argnums[i]], gradients, counts[i]); + } + positional_grads = py::cast(grads_); + } else { + positional_grads = py::none(); + } + + // No keyword argument gradients so return the tuple of gradients + if (argnames.size() == 0) { + py_grads = positional_grads; + } else { + py::dict grads_; + for (int i = 0; i < argnames.size(); i++) { + auto& k = argnames[i]; + grads_[k.c_str()] = tree_unflatten( + kwargs[k.c_str()], gradients, counts[i + argnums.size()]); + } + keyword_grads = py::cast(grads_); + + py_grads = + py::cast(py::make_tuple(positional_grads, keyword_grads)); + } + + // Put the values back in the container + py::object return_value = tree_unflatten(py_value_out, value); + return std::make_pair(return_value, py_grads); + }; +} + +auto py_vmap( + const py::function& fun, + const py::object& in_axes, + const py::object& out_axes) { + return [fun, in_axes, out_axes](const py::args& args) { + auto axes_to_flat_tree = [](const py::object& tree, + const py::object& axes) { + auto tree_axes = tree_map( + {tree, axes}, + [](const std::vector& inputs) { return inputs[1]; }); + std::vector flat_axes; + tree_visit(tree_axes, [&flat_axes](py::handle obj) { + if (obj.is_none()) { + flat_axes.push_back(-1); + } else if (py::isinstance(obj)) { + flat_axes.push_back(py::cast(py::cast(obj))); + } else { + throw std::invalid_argument("[vmap] axis must be int or None."); + } + }); + return flat_axes; + }; + + // Inputs must be array or tree of arrays + auto inputs = tree_flatten(args, true); + auto flat_in_axes = axes_to_flat_tree(args, in_axes); + + // py_value_out will hold the output of the python function in order to be + // able to reconstruct the python tree of extra return values + py::object py_outputs; + + auto vmap_fn = + [&fun, &args, &inputs, &py_outputs](const std::vector& a) { + // Call the python function + py_outputs = fun(*tree_unflatten(args, a)); + + // Flatten the outputs + return tree_flatten(py_outputs, true); + }; + + auto [trace_inputs, trace_outputs] = + detail::vmap_trace(vmap_fn, inputs, flat_in_axes); + + auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes); + + // Perform the vmap + auto outputs = detail::vmap_replace( + inputs, trace_inputs, trace_outputs, flat_in_axes, flat_out_axes); + + // Put the outputs back in the container + return tree_unflatten(py_outputs, outputs); + }; +} + +void init_transforms(py::module_& m) { + m.def( + "eval", + [](const py::args& args, bool retain_graph) { + std::vector arrays = tree_flatten(args); + eval(arrays, retain_graph); + }, + "retain_graph"_a = false, + R"pbdoc( + Evaluate an :class:`array` or tree of :class:`array`. + + Args: + *args (arrays or trees of arrays): Each argument can be a single array + or a tree of arrays. If a tree is given the nodes can be a Python + :class:`list`, :class:`tuple` or :class:`dict` but the leafs must all be + an :class:`array`. + retain_graph (bool): Indicate that the graph structure should be + preserved. This option is intended to enable function transforms + which contain control flow based on the value of an array. 
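+
+ For example, a minimal usage sketch (the variable names here are
+ illustrative only):
+
+ .. code-block:: python
+
+ import mlx.core as mx
+
+ a = mx.ones((2, 2))
+ b = a + 1
+
+ # Evaluate a single array ...
+ mx.eval(b)
+
+ # ... or any tree of arrays (lists, tuples and dicts are allowed)
+ mx.eval({"a": a, "pair": (a, b)})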
+ )pbdoc"); + m.def( + "jvp", + [](const py::function& fun, + const std::vector& primals, + const std::vector& tangents) { + auto vfun = [&fun](const std::vector& primals) { + py::args args = py::tuple(primals.size()); + for (int i = 0; i < primals.size(); ++i) { + args[i] = primals[i]; + } + auto out = fun(*args); + if (py::isinstance(out)) { + return std::vector{py::cast(out)}; + } else { + return py::cast>(out); + } + }; + return jvp(vfun, primals, tangents); + }, + "fun"_a, + "primals"_a, + "tangents"_a, + R"pbdoc( + Compute the Jacobian-vector product. + + This computes the product of the Jacobian of a function ``fun`` evaluated + at ``primals`` with the ``tangents``. + + Args: + fun (function): A function which takes a variable number of :class:`array` + and returns a single :class:`array` or list of :class:`array`. + primals (list(array)): A list of :class:`array` at which to + evaluate the Jacobian. + tangents (list(array)): A list of :class:`array` which are the + "vector" in the Jacobian-vector product. The ``tangents`` should be the + same in number, shape, and type as the inputs of ``fun`` (i.e. the ``primals``). + + Returns: + list(array): A list of the Jacobian-vector products which + is the same in number, shape, and type of the inputs to ``fun``. + )pbdoc"); + m.def( + "vjp", + [](const py::function& fun, + const std::vector& primals, + const std::vector& cotangents) { + auto vfun = [&fun](const std::vector& primals) { + py::args args = py::tuple(primals.size()); + for (int i = 0; i < primals.size(); ++i) { + args[i] = primals[i]; + } + auto out = fun(*args); + if (py::isinstance(out)) { + return std::vector{py::cast(out)}; + } else { + return py::cast>(out); + } + }; + return vjp(vfun, primals, cotangents); + }, + "fun"_a, + "primals"_a, + "cotangents"_a, + R"pbdoc( + Compute the vector-Jacobian product. + + Computes the product of the ``cotangents`` with the Jacobian of a + function ``fun`` evaluated at ``primals``. + + Args: + fun (function): A function which takes a variable number of :class:`array` + and returns a single :class:`array` or list of :class:`array`. + primals (list(array)): A list of :class:`array` at which to + evaluate the Jacobian. + cotangents (list(array)): A list of :class:`array` which are the + "vector" in the vector-Jacobian product. The ``cotangents`` should be the + same in number, shape, and type as the outputs of ``fun``. + + Returns: + list(array): A list of the vector-Jacobian products which + is the same in number, shape, and type of the outputs of ``fun``. + )pbdoc"); + m.def( + "value_and_grad", + [](const py::function& fun, + const std::optional& argnums, + const StrOrVec& argnames) { + auto [argnums_vec, argnames_vec] = + validate_argnums_argnames(argnums, argnames); + return py::cpp_function(py_value_and_grad( + fun, argnums_vec, argnames_vec, "[value_and_grad]", false)); + }, + "fun"_a, + "argnums"_a = std::nullopt, + "argnames"_a = std::vector{}, + R"pbdoc( + Returns a function which computes the value and gradient of ``fun``. + + The function passed to :func:`value_and_grad` should return either + a scalar loss or a tuple in which the first element is a scalar + loss and the remaining elements can be anything. + + .. 
code-block:: python + + import mlx.core as mx + + def mse(params, inputs, targets): + outputs = forward(params, inputs) + lvalue = (outputs - targets).square().mean() + return lvalue + + # Returns lvalue, dlvalue/dparams + lvalue, grads = mx.value_and_grad(mse) + + def lasso(params, inputs, targets, a=1.0, b=1.0): + outputs = forward(params, inputs) + mse = (outputs - targets).square().mean() + l1 = mx.abs(outputs - targets).mean() + + loss = a*mse + b*l1 + + return loss, mse, l1 + + (loss, mse, l1), grads = mx.value_and_grad(lasso) + + Args: + fun (function): A function which takes a variable number of + :class:`array` or trees of :class:`array` and returns + a scalar output :class:`array` or a tuple the first element + of which should be a scalar :class:`array`. + argnums (int or list(int), optional): Specify the index (or indices) + of the positional arguments of ``fun`` to compute the gradient + with respect to. If neither ``argnums`` nor ``argnames`` are + provided ``argnums`` defaults to ``0`` indicating ``fun``'s first + argument. + argnames (str or list(str), optional): Specify keyword arguments of + ``fun`` to compute gradients with respect to. It defaults to [] so + no gradients for keyword arguments by default. + + Returns: + function: A function which returns a tuple where the first element + is the output of `fun` and the second element is the gradients w.r.t. + the loss. + )pbdoc"); + m.def( + "grad", + [](const py::function& fun, + const std::optional& argnums, + const StrOrVec& argnames) { + auto [argnums_vec, argnames_vec] = + validate_argnums_argnames(argnums, argnames); + auto fn = + py_value_and_grad(fun, argnums_vec, argnames_vec, "[grad]", true); + return py::cpp_function( + [fn](const py::args& args, const py::kwargs& kwargs) { + return fn(args, kwargs).second; + }); + }, + "fun"_a, + "argnums"_a = std::nullopt, + "argnames"_a = std::vector{}, + R"pbdoc( + Returns a function which computes the gradient of ``fun``. + + Args: + fun (function): A function which takes a variable number of + :class:`array` or trees of :class:`array` and returns + a scalar output :class:`array`. + argnums (int or list(int), optional): Specify the index (or indices) + of the positional arguments of ``fun`` to compute the gradient + with respect to. If neither ``argnums`` nor ``argnames`` are + provided ``argnums`` defaults to ``0`` indicating ``fun``'s first + argument. + argnames (str or list(str), optional): Specify keyword arguments of + ``fun`` to compute gradients with respect to. It defaults to [] so + no gradients for keyword arguments by default. + + Returns: + function: A function which has the same input arguments as ``fun`` and + returns the gradient(s). + )pbdoc"); + m.def( + "vmap", + [](const py::function& fun, + const py::object& in_axes, + const py::object& out_axes) { + return py::cpp_function(py_vmap(fun, in_axes, out_axes)); + }, + "fun"_a, + "in_axes"_a = 0, + "out_axes"_a = 0, + R"pbdoc( + Returns a vectorized version of ``fun``. + + Args: + fun (function): A function which takes a variable number of + :class:`array` or a tree of :class:`array` and returns + a variable number of :class:`array` or a tree of :class:`array`. + in_axes (int, optional): An integer or a valid prefix tree of the + inputs to ``fun`` where each node specifies the vmapped axis. If + the value is ``None`` then the corresponding input(s) are not vmapped. + Defaults to ``0``. 
+ out_axes (int, optional): An integer or a valid prefix tree of the + outputs of ``fun`` where each node specifies the vmapped axis. If + the value is ``None`` then the corresponding outputs(s) are not vmapped. + Defaults to ``0``. + + Returns: + function: The vectorized function. + )pbdoc"); + m.def( + "simplify", + [](const py::args& args) { + std::vector arrays = tree_flatten(args); + simplify(arrays); + }, + R"pbdoc( + Simplify the graph that computes the arrays. + + Run a few fast graph simplification operations to reuse computation and + reduce memory consumption. This function is meant to be run every time + so its overhead should be small, approximately 1ms for a graph with a + few thousand nodes. + + .. code-block:: python + + import mlx.core as mx + + def foo(x): + y = x @ x + z = x @ x + return y + z + + x = mx.ones((10, 10)) + y = foo(x) + z = foo(x) + + # Computes the matmul twice + mx.eval(y) + + # Computes the matmul once + mx.simplify(z) + mx.eval(z) + + Args: + args: Any number of arrays and/or trees of arrays to be simplified. + )pbdoc"); + m.def( + "export_to_dot", + [](py::object file, const py::args& args) { + std::vector arrays = tree_flatten(args); + if (py::isinstance(file)) { + std::ofstream out(py::cast(file)); + export_to_dot(out, arrays); + } else if (py::hasattr(file, "write")) { + std::ostringstream out; + export_to_dot(out, arrays); + auto write = file.attr("write"); + write(out.str()); + } else { + throw std::invalid_argument( + "export_to_dot accepts file-like objects or strings to be used as filenames"); + } + }, + "file"_a); +} diff --git a/python/tests/mlx_tests.py b/python/tests/mlx_tests.py new file mode 100644 index 000000000..6c19f6eb8 --- /dev/null +++ b/python/tests/mlx_tests.py @@ -0,0 +1,16 @@ +import os +import unittest + +import mlx.core as mx + + +class MLXTestCase(unittest.TestCase): + def setUp(self): + self.default = mx.default_device() + device = os.getenv("DEVICE", None) + if device is not None: + device = getattr(mx, device) + mx.set_default_device(device) + + def tearDown(self): + mx.set_default_device(self.default) diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py new file mode 100644 index 000000000..3a81e10c7 --- /dev/null +++ b/python/tests/test_blas.py @@ -0,0 +1,445 @@ +import unittest +from itertools import permutations + +import math +import mlx.core as mx +import numpy as np + +import mlx_tests + + +class TestBlas(mlx_tests.MLXTestCase): + @property + def dtypes(self): + return ["float32", "float16"] if mx.metal.is_available() else ["float32"] + + def __gemm_test( + self, + shape_a, + shape_b, + np_dtype=np.float32, + f_np_a=lambda x: x, + f_np_b=lambda x: x, + f_mx_a=lambda x: x, + f_mx_b=lambda x: x, + ): + with self.subTest( + dtype=np.dtype(np_dtype).name, shape_a=shape_a, shape_b=shape_b + ): + np.random.seed(42) + scale = max(np.sum(shape_a), 128) + a_np = np.random.normal(0.0, 1.0 / scale, shape_a).astype(np_dtype) + b_np = np.random.normal(0.0, 1.0 / scale, shape_b).astype(np_dtype) + + a_mx = mx.array(a_np) + b_mx = mx.array(b_np) + + a_np = f_np_a(a_np.astype(np.float32)) + b_np = f_np_b(b_np.astype(np.float32)) + a_mx = f_mx_a(a_mx) + b_mx = f_mx_b(b_mx) + + out_npy = a_np @ b_np + out_mlx = a_mx @ b_mx + + self.assertListEqual(list(out_npy.shape), list(out_mlx.shape)) + self.assertTrue(np.allclose(out_mlx, out_npy.astype(np_dtype), atol=1e-5)) + + def test_matmul_unaligned(self): + if not mx.metal.is_available(): + return + + for dtype in self.dtypes: + np_dtype = getattr(np, dtype) + base_shapes = 
[4, 8, 16, 32, 64, 128] + pertubations = [-2, -1, 0, 1, 2] + + for dim in base_shapes: + for p in pertubations: + shape_a = (dim + p, dim + p) + shape_b = (dim + p, dim + p) + self.__gemm_test(shape_a, shape_b, np_dtype) + + def test_matmul_shapes(self): + if not mx.metal.is_available(): + return + + shapes = [ + (1, 2, 1, 1), + (1, 1, 2, 1), + (3, 23, 457, 3), + ] + + if mx.default_device() == mx.gpu: + shapes += [ + (16, 768, 768, 128), + ] + + for dtype in self.dtypes: + np_dtype = getattr(np, dtype) + + for B, M, N, K in shapes: + + with self.subTest(tranpose="nn"): + shape_a = (B, M, K) + shape_b = (B, K, N) + self.__gemm_test(shape_a, shape_b, np_dtype) + + with self.subTest(tranpose="nt"): + shape_a = (B, M, K) + shape_b = (B, N, K) + self.__gemm_test( + shape_a, + shape_b, + np_dtype, + f_np_b=lambda x: np.transpose(x, (0, 2, 1)), + f_mx_b=lambda x: mx.transpose(x, (0, 2, 1)), + ) + + with self.subTest(tranpose="tn"): + shape_a = (B, K, M) + shape_b = (B, K, N) + self.__gemm_test( + shape_a, + shape_b, + np_dtype, + f_np_a=lambda x: np.transpose(x, (0, 2, 1)), + f_mx_a=lambda x: mx.transpose(x, (0, 2, 1)), + ) + + with self.subTest(tranpose="tt"): + shape_a = (B, K, M) + shape_b = (B, N, K) + self.__gemm_test( + shape_a, + shape_b, + np_dtype, + f_np_a=lambda x: np.transpose(x, (0, 2, 1)), + f_mx_a=lambda x: mx.transpose(x, (0, 2, 1)), + f_np_b=lambda x: np.transpose(x, (0, 2, 1)), + f_mx_b=lambda x: mx.transpose(x, (0, 2, 1)), + ) + + def test_matmul(self): + # Note: so far, matmul only works with floating-point types + a = mx.array([[1.0, 2.0], [3.0, 4.0]]) + + b = mx.array([[0.0, -1.0], [-3.0, 3.0]]) + + expected = [[-6.0, 5.0], [-12.0, 9.0]] + + self.assertEqual((a @ b).tolist(), expected) + self.assertEqual(mx.matmul(a, b).tolist(), expected) + + # Transposed matmul + np.random.seed(0) + a_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32) + b_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32) + c_npy = a_npy @ np.transpose(b_npy, (1, 0)) + d_npy = np.transpose(a_npy, (1, 0)) @ b_npy + + a_mlx = mx.array(a_npy) + b_mlx = mx.array(b_npy) + c_mlx = a_mlx @ mx.transpose(b_mlx, (1, 0)) + d_mlx = mx.transpose(a_mlx, (1, 0)) @ b_mlx + + self.assertListEqual(list(c_npy.shape), list(c_mlx.shape)) + self.assertListEqual(list(d_npy.shape), list(d_mlx.shape)) + + self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6)) + self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-6)) + + def test_matmul_dtypes(self): + + for dt in self.dtypes: + a_npy = np.random.normal(0.0, 1.0 / 256, (16, 16, 16)).astype( + getattr(np, dt) + ) + b_npy = np.random.normal(0.0, 1.0 / 256, (16, 16, 16)).astype( + getattr(np, dt) + ) + a_mlx = mx.array(a_npy) + b_mlx = mx.array(b_npy) + + c_npy = np.matmul(a_npy, b_npy, dtype=getattr(np, dt)) + c_mlx = a_mlx @ b_mlx + + self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6)) + + def test_matmul_batched(self): + np.random.seed(0) + # Batched matmul + a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32) + b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32) + c_npy = a_npy @ b_npy + + a_mlx = mx.array(a_npy) + b_mlx = mx.array(b_npy) + c_mlx = a_mlx @ b_mlx + + self.assertListEqual(list(c_npy.shape), list(c_mlx.shape)) + self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6)) + + # Batched and transposed matmul + b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32) + c_npy = a_npy @ np.transpose(b_npy, (0, 2, 1)) + + b_mlx = mx.array(b_npy) + c_mlx = a_mlx @ 
mx.transpose(b_mlx, (0, 2, 1)) + + self.assertListEqual(list(c_npy.shape), list(c_mlx.shape)) + self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6)) + + # Batched matmul with simple broadast + a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32) + b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32) + c_npy = a_npy @ b_npy + + a_mlx = mx.array(a_npy) + b_mlx = mx.array(b_npy) + c_mlx = a_mlx @ b_mlx + + self.assertListEqual(list(c_npy.shape), list(c_mlx.shape)) + self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6)) + + # Both operands broadcasted + d_npy = np.broadcast_to(b_npy, (5, 16, 16)) + d_mlx = mx.broadcast_to(b_mlx, (5, 16, 16)) + + e_npy = d_npy @ d_npy + e_mlx = d_mlx @ d_mlx + + self.assertListEqual(list(e_npy.shape), list(e_mlx.shape)) + self.assertTrue(np.allclose(e_mlx, e_npy, atol=1e-6)) + + # Batched and transposed matmul with simple broadast + a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32) + b_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32) + a_mlx = mx.array(a_npy) + b_mlx = mx.array(b_npy) + + c_npy = a_npy @ np.transpose(b_npy, (1, 0)) + c_mlx = a_mlx @ mx.transpose(b_mlx, (1, 0)) + + self.assertListEqual(list(c_npy.shape), list(c_mlx.shape)) + self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6)) + + # Matmul with vector + a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32) + b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32) + a_mlx = mx.array(a_npy) + b_mlx = mx.array(b_npy) + + c_npy = a_npy @ b_npy + c_mlx = a_mlx @ b_mlx + + self.assertListEqual(list(c_npy.shape), list(c_mlx.shape)) + self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6)) + + # Test Multiheaded attention style matmul + a_npy = np.random.normal(0.0, 1.0 / 128, (64, 16, 4, 32)).astype(np.float32) + b_npy = np.random.normal(0.0, 1.0 / 128, (64, 16, 4, 32)).astype(np.float32) + a_mlx = mx.array(a_npy) + b_mlx = mx.array(b_npy) + + a_npy = np.transpose(a_npy, (0, 2, 1, 3)) + b_npy = np.transpose(b_npy, (0, 2, 1, 3)) + a_mlx = mx.transpose(a_mlx, (0, 2, 1, 3)) + b_mlx = mx.transpose(b_mlx, (0, 2, 1, 3)) + + c_npy = a_npy @ np.transpose(b_npy, (0, 1, 3, 2)) + c_mlx = a_mlx @ mx.transpose(b_mlx, (0, 1, 3, 2)) + self.assertListEqual(list(c_npy.shape), list(c_mlx.shape)) + self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6)) + + def __gemv_test( + self, + shape_mat, + shape_vec, + np_dtype=np.float32, + mat_first=True, + np_mat_f=lambda x: x, + np_vec_f=lambda x: x, + mlx_mat_f=lambda x: x, + mlx_vec_f=lambda x: x, + ): + with self.subTest(shape=shape_mat): + np.random.seed(42) + scale = max(np.sum(shape_mat), 32) + mat_npy = np.random.normal(0.0, 1.0 / scale, shape_mat).astype(np_dtype) + vec_npy = np.random.normal(0.0, 1.0 / scale, shape_vec).astype(np_dtype) + + mat_mlx = mx.array(mat_npy) + vec_mlx = mx.array(vec_npy) + + mat_npy = np_mat_f(mat_npy) + vec_npy = np_vec_f(vec_npy) + mat_mlx = mlx_mat_f(mat_mlx) + vec_mlx = mlx_vec_f(vec_mlx) + + if mat_first: + out_npy = mat_npy @ vec_npy + out_mlx = mat_mlx @ vec_mlx + else: + out_npy = vec_npy @ mat_npy + out_mlx = vec_mlx @ mat_mlx + + self.assertListEqual(list(out_npy.shape), list(out_mlx.shape)) + self.assertTrue(np.allclose(out_mlx, out_npy, atol=1e-5)) + + def test_matrix_vector(self): + for dtype in self.dtypes: + with self.subTest(dtype=dtype): + np_dtype = getattr(np, dtype) + + # Basic square matrix test + self.__gemv_test( + shape_mat=(64, 64), shape_vec=(64, 1), np_dtype=np_dtype + ) + self.__gemv_test( + 
shape_mat=(64, 64), + shape_vec=(64, 1), + np_dtype=np_dtype, + mat_first=False, + np_vec_f=lambda x: np.transpose(x, (1, 0)), + mlx_vec_f=lambda x: mx.transpose(x, (1, 0)), + ) + + # Vector matrix product with aligned and unaligned shapes + for in_len_base, out_len_base in ( + (2, 2), + (32, 32), + (64, 64), + (2048, 2048), + ): + for mi in (-1, 0, 1): + for mj in (-1, 0, 1): + # Vec mat + shape_mat = (in_len_base + mi, out_len_base + mj) + shape_vec = (1, in_len_base + mi) + self.__gemv_test( + shape_mat, shape_vec, mat_first=False, np_dtype=np_dtype + ) + + # Mat vec + shape_mat = (out_len_base + mj, in_len_base + mi) + shape_vec = (in_len_base + mi, 1) + self.__gemv_test( + shape_mat, shape_vec, mat_first=True, np_dtype=np_dtype + ) + + def test_matrix_vector_batched(self): + for dtype in self.dtypes: + with self.subTest(dtype=dtype): + np_dtype = getattr(np, dtype) + + # Batched mat vec + for shape_mat, shape_vec in ( + ((32, 128, 64), (32, 64, 1)), + ((128, 64), (32, 64, 1)), + ((32, 128, 64), (64, 1)), + ): + self.__gemv_test( + shape_mat, shape_vec, mat_first=True, np_dtype=np_dtype + ) + + # Batched vec mat + for shape_vec, shape_mat in ( + ((32, 1, 128), (32, 128, 64)), + ((32, 1, 128), (128, 64)), + ((1, 128), (32, 128, 64)), + ): + self.__gemv_test( + shape_mat, shape_vec, mat_first=False, np_dtype=np_dtype + ) + + def test_matrix_vector_broadcast(self): + for dtype in self.dtypes: + with self.subTest(dtype=dtype): + np_dtype = getattr(np, dtype) + + # Different broadcasts mat vec + for shape_mat, shape_vec in ( + ((32, 64, 64), (32, 64, 1)), + ((64, 64), (32, 64, 1)), + ((32, 64, 64), (64, 1)), + ): + self.__gemv_test( + shape_mat=(64, 64), + shape_vec=(64, 1), + np_dtype=np_dtype, + np_mat_f=(lambda mat_npy: np.broadcast_to(mat_npy, shape_mat)), + np_vec_f=(lambda vec_npy: np.broadcast_to(vec_npy, shape_vec)), + mlx_mat_f=(lambda mat_mlx: mx.broadcast_to(mat_mlx, shape_mat)), + mlx_vec_f=(lambda vec_mlx: mx.broadcast_to(vec_mlx, shape_vec)), + ) + + # Different broadcasts vec mat + for shape_vec, shape_mat in ( + ((32, 1, 64), (32, 64, 64)), + ((32, 1, 64), (64, 64)), + ((1, 64), (32, 64, 64)), + ): + self.__gemv_test( + shape_mat=(64, 64), + shape_vec=(1, 64), + np_dtype=np_dtype, + mat_first=False, + np_mat_f=lambda mat_npy: np.broadcast_to(mat_npy, shape_mat), + np_vec_f=lambda vec_npy: np.broadcast_to(vec_npy, shape_vec), + mlx_mat_f=lambda mat_mlx: mx.broadcast_to(mat_mlx, shape_mat), + mlx_vec_f=lambda vec_mlx: mx.broadcast_to(vec_mlx, shape_vec), + ) + + def test_matrix_vector_edgecases(self): + for dtype in self.dtypes: + with self.subTest(dtype=dtype): + np_dtype = getattr(np, dtype) + + for in_vec_len in np.arange(1, 5): + for out_vec_len in np.arange(1, 5): + for batch_size in np.arange(1, 5): + with self.subTest( + problem_shape=(batch_size, in_vec_len, out_vec_len) + ): + # Matrix vector + with self.subTest(transpose=False): + a_npy = np.ones( + (batch_size, out_vec_len, in_vec_len), + dtype=np_dtype, + ) + b_npy = np.ones( + (batch_size, in_vec_len, 1), dtype=np_dtype + ) + for i in range(batch_size): + b_npy[i] *= i + 1.0 + + a_mlx, b_mlx = map(mx.array, [a_npy, b_npy]) + c_npy = a_npy @ b_npy + c_mlx = a_mlx @ b_mlx + + self.assertListEqual( + list(c_npy.shape), list(c_mlx.shape) + ) + self.assertTrue(np.array_equal(c_mlx, c_npy)) + + # Vector matrix + with self.subTest(transpose=True): + a_npy = np.ones( + (batch_size, out_vec_len, in_vec_len), + dtype=np_dtype, + ) + b_npy = np.ones( + (batch_size, 1, out_vec_len), dtype=np_dtype + ) + for i in 
range(batch_size): + b_npy[i] *= i + 1.0 + + a_mlx, b_mlx = map(mx.array, [a_npy, b_npy]) + c_npy = b_npy @ a_npy + c_mlx = b_mlx @ a_mlx + + self.assertListEqual( + list(c_npy.shape), list(c_mlx.shape) + ) + self.assertTrue(np.array_equal(c_mlx, c_npy)) diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py new file mode 100644 index 000000000..c4df60357 --- /dev/null +++ b/python/tests/test_conv.py @@ -0,0 +1,445 @@ +import unittest +from itertools import permutations + +import math +import mlx.core as mx +import numpy as np + +import mlx_tests + +try: + import torch + import torch.nn.functional as F + + has_torch = True +except ImportError as e: + has_torch = False + + +class TestConv(mlx_tests.MLXTestCase): + def test_numpy_conv(self): + for dtype in ( + "float16", + "float32", + ): + np_dtype = getattr(np, dtype) + for M, N, mode in ( + (1, 1, "full"), + (25, 5, "full"), + (24, 5, "same"), + (24, 4, "same"), + (24, 4, "valid"), + (4, 24, "full"), + (5, 25, "same"), + (4, 25, "valid"), + ): + with self.subTest(dtype=dtype, M=M, N=N, mode=mode): + atol = 1e-6 if dtype == "float32" else 1e-5 + a_np = np.random.rand(M).astype(np_dtype) + v_np = np.random.rand(N).astype(np_dtype) + a_mx = mx.array(a_np) + v_mx = mx.array(v_np) + + c_np = np.convolve(a_np, v_np, mode=mode) + c_mx = mx.convolve(a_mx, v_mx, mode=mode) + + self.assertListEqual(list(c_mx.shape), list(c_np.shape)) + self.assertTrue(np.allclose(c_mx, c_np, atol=atol)) + + @unittest.skipIf(not has_torch, "requires Torch") + def test_torch_conv_1D(self): + def run_conv1D( + N, + C, + O, + iH, + kH, + stride, + padding, + dilation=1, + groups=1, + dtype="float32", + atol=1e-5, + ): + with self.subTest( + dtype=dtype, + N=N, + C=C, + O=O, + iH=iH, + kH=kH, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ): + np_dtype = getattr(np, dtype) + np.random.seed(0) + in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype) + wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype) + + in_mx, wt_mx = map(mx.array, (in_np, wt_np)) + in_pt, wt_pt = map( + lambda x: torch.from_numpy(x.transpose(0, 2, 1)), (in_np, wt_np) + ) + + out_mx = mx.conv1d( + in_mx, + wt_mx, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + out_pt = torch.conv1d( + in_pt, + wt_pt, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + out_pt = torch.transpose(out_pt, 2, 1) + + self.assertListEqual(list(out_pt.shape), out_mx.shape) + self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol)) + + for dtype in ("float32",): + for N, C, O in ( + (1, 1, 1), + (1, 6, 1), + (1, 1, 6), + (4, 32, 64), + ): + for iH, kH, stride, padding in ( + (1, 1, 1, 0), + (3, 3, 1, 0), + (31, 5, 5, 2), + ): + run_conv1D(N, C, O, iH, kH, stride, padding, dtype=dtype) + + # Strided inputs tests + for tpose_in, tpose_wt in ( + ((0, 2, 1), (0, 1, 2)), + ((0, 2, 1), (0, 2, 1)), + ): + with self.subTest(name="strided", tpose_in=tpose_in, tpose_wt=tpose_wt): + in_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32) + wt_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32) + + in_mx, wt_mx = map(mx.array, (in_np, wt_np)) + in_mx_t = mx.transpose(in_mx, tpose_in) + wt_mx_t = mx.transpose(wt_mx, tpose_wt) + out_mx = mx.conv1d(in_mx_t, wt_mx_t) + + in_pt, wt_pt = map( + lambda x: torch.from_numpy(x.transpose(0, 2, 1)), + (in_np.transpose(tpose_in), wt_np.transpose(tpose_wt)), + ) + + out_pt = torch.conv1d(in_pt, wt_pt) + out_pt = torch.transpose(out_pt, 2, 1) 
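+ # out_pt is now back in MLX's (N, L, C) layout, so the two outputs can be
+ # compared directly below.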
+ + self.assertListEqual(list(out_pt.shape), out_mx.shape) + self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=1e-5)) + + @unittest.skipIf(not has_torch, "requires Torch") + def test_torch_conv_1D_grad(self): + def run_conv1D_grad( + N, + C, + O, + iH, + kH, + stride, + padding, + dilation=1, + groups=1, + dtype="float32", + atol=1e-5, + ): + with self.subTest( + dtype=dtype, + N=N, + C=C, + O=O, + iH=iH, + kH=kH, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ): + np_dtype = getattr(np, dtype) + np.random.seed(0) + oH = 1 + ((iH + 2 * padding - dilation * (kH - 1) - 1) // stride) + + in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype) + wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype) + ct_np = np.random.normal(0, 1.0 / C, (N, oH, O)).astype(np_dtype) + + in_mx, wt_mx, ct_mx = map(mx.array, (in_np, wt_np, ct_np)) + in_pt, wt_pt, ct_pt = map( + lambda x: torch.from_numpy(x.transpose(0, 2, 1)), + (in_np, wt_np, ct_np), + ) + + def f(a, b): + return mx.conv1d( + a, + b, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + _, outs_mx = mx.vjp( + f, + [ + in_mx, + wt_mx, + ], + [ + ct_mx, + ], + ) + pt_grad_in = F.grad.conv1d_input( + in_pt.shape, + wt_pt, + ct_pt, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + pt_grad_wt = F.grad.conv1d_weight( + in_pt, + wt_pt.shape, + ct_pt, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + pt_grad_in = torch.transpose(pt_grad_in, 2, 1).numpy() + pt_grad_wt = torch.transpose(pt_grad_wt, 2, 1).numpy() + + mx_grad_in, mx_grad_wt = outs_mx + + self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape) + self.assertListEqual(list(in_mx.shape), mx_grad_in.shape) + self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol)) + + self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape) + self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape) + self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol)) + + for dtype in ("float32",): + for N, C, O in ( + (1, 1, 1), + (1, 6, 1), + (1, 1, 6), + (4, 32, 64), + ): + for iH, kH, stride, padding in ( + (1, 1, 1, 0), + (3, 3, 1, 0), + (31, 5, 5, 2), + ): + run_conv1D_grad(N, C, O, iH, kH, stride, padding, dtype=dtype) + + @unittest.skipIf(not has_torch, "requires Torch") + def test_torch_conv_2D(self): + def run_conv2D( + N, + C, + O, + idim, + kdim, + stride, + padding, + dilation=(1, 1), + groups=1, + dtype="float32", + atol=1e-5, + ): + with self.subTest( + dtype=dtype, + N=N, + C=C, + O=O, + idim=idim, + kdim=kdim, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ): + np_dtype = getattr(np, dtype) + np.random.seed(0) + iH, iW = idim + kH, kW = kdim + scale = 1.0 / math.sqrt(kH * kW * C) + in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype) + wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, C)).astype(np_dtype) + + in_mx, wt_mx = map(mx.array, (in_np, wt_np)) + in_pt, wt_pt = map( + lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2)).to("cpu"), + (in_np, wt_np), + ) + + out_mx = mx.conv2d( + in_mx, + wt_mx, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + out_pt = torch.conv2d( + in_pt, + wt_pt, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True) + + self.assertListEqual(list(out_pt.shape), list(out_mx.shape)) + self.assertTrue(np.allclose(out_pt, out_mx, atol=atol)) + + for dtype 
in ("float32",): + for N, C, O in ( + (1, 1, 1), + (1, 6, 1), + (1, 1, 6), + (4, 32, 64), + ): + for idim, kdim, stride, padding in ( + ((1, 1), (1, 1), (1, 1), (0, 0)), + ((3, 3), (3, 1), (1, 1), (0, 0)), + ((31, 31), (5, 5), (5, 5), (2, 2)), + ): + run_conv2D(N, C, O, idim, kdim, stride, padding, dtype=dtype) + + @unittest.skipIf(not has_torch, "requires Torch") + def test_torch_conv_2D_grad(self): + def run_conv2D_grad( + N, + C, + O, + idim, + kdim, + stride, + padding, + dilation=(1, 1), + groups=1, + dtype="float32", + atol=1e-5, + ): + with self.subTest( + dtype=dtype, + N=N, + C=C, + O=O, + idim=idim, + kdim=kdim, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ): + np_dtype = getattr(np, dtype) + np.random.seed(0) + iH, iW = idim + kH, kW = kdim + scale = 1.0 / math.sqrt(kH * kW * C) + + oH = 1 + ( + (iH + 2 * padding[0] - dilation[0] * (kH - 1) - 1) // stride[0] + ) + oW = 1 + ( + (iW + 2 * padding[1] - dilation[1] * (kW - 1) - 1) // stride[1] + ) + + in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype) + wt_np = np.random.normal(0.0, scale, (O, kH, kW, C)).astype(np_dtype) + ct_np = np.random.normal(0.0, scale, (N, oH, oW, O)).astype(np_dtype) + + in_mx, wt_mx, ct_mx = map(mx.array, (in_np, wt_np, ct_np)) + in_pt, wt_pt, ct_pt = map( + lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2)).to("cpu"), + (in_np, wt_np, ct_np), + ) + + def f(a, b): + return mx.conv2d( + a, + b, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + _, outs_mx = mx.vjp( + f, + [ + in_mx, + wt_mx, + ], + [ + ct_mx, + ], + ) + pt_grad_in = F.grad.conv1d_input( + in_pt.shape, + wt_pt, + ct_pt, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + pt_grad_wt = F.grad.conv1d_weight( + in_pt, + wt_pt.shape, + ct_pt, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + pt_grad_in = torch.permute(pt_grad_in, (0, 2, 3, 1)).numpy() + pt_grad_wt = torch.permute(pt_grad_wt, (0, 2, 3, 1)).numpy() + + mx_grad_in, mx_grad_wt = outs_mx + + self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape) + self.assertListEqual(list(in_mx.shape), mx_grad_in.shape) + self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol)) + + self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape) + self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape) + self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol)) + + for dtype in ("float32",): + for N, C, O in ( + (1, 1, 1), + (1, 6, 1), + (1, 1, 6), + (4, 32, 64), + ): + for idim, kdim, stride, padding in ( + ((1, 1), (1, 1), (1, 1), (0, 0)), + ((3, 3), (3, 1), (1, 1), (0, 0)), + ((31, 31), (5, 5), (5, 5), (2, 2)), + ): + run_conv2D_grad(N, C, O, idim, kdim, stride, padding, dtype=dtype) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_load.py b/python/tests/test_load.py new file mode 100644 index 000000000..ebaddf985 --- /dev/null +++ b/python/tests/test_load.py @@ -0,0 +1,157 @@ +import unittest +import os +import mlx.core as mx +import numpy as np +import tempfile + +import mlx_tests + + +class TestLoad(mlx_tests.MLXTestCase): + dtypes = [ + "uint8", + "uint16", + "uint32", + "uint64", + "int8", + "int16", + "int32", + "int64", + "float32", + "float16", + "complex64", + ] + + @classmethod + def setUpClass(cls): + cls.test_dir_fid = tempfile.TemporaryDirectory() + cls.test_dir = cls.test_dir_fid.name + + @classmethod + def tearDownClass(cls): + cls.test_dir_fid.cleanup() + + def test_save_and_load(self): + + 
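+ # Round-trip arrays through .npy files written by both MLX and NumPy and
+ # check that each library can read what the other saved.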
if not os.path.isdir(self.test_dir): + os.mkdir(self.test_dir) + + for dt in self.dtypes: + with self.subTest(dtype=dt): + for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]): + with self.subTest(shape=shape): + save_file_mlx = os.path.join(self.test_dir, f"mlx_{dt}_{i}.npy") + save_file_npy = os.path.join(self.test_dir, f"npy_{dt}_{i}.npy") + + save_arr = np.random.uniform(0.0, 32.0, size=shape) + save_arr_npy = save_arr.astype(getattr(np, dt)) + save_arr_mlx = mx.array(save_arr_npy) + + mx.save(save_file_mlx, save_arr_mlx) + np.save(save_file_npy, save_arr_npy) + + # Load array saved by mlx as mlx array + load_arr_mlx_mlx = mx.load(save_file_mlx) + self.assertTrue(mx.array_equal(load_arr_mlx_mlx, save_arr_mlx)) + + # Load array saved by numpy as mlx array + load_arr_npy_mlx = mx.load(save_file_npy) + self.assertTrue(mx.array_equal(load_arr_npy_mlx, save_arr_mlx)) + + # Load array saved by mlx as numpy array + load_arr_mlx_npy = np.load(save_file_mlx) + self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy)) + + def test_save_and_load_fs(self): + + if not os.path.isdir(self.test_dir): + os.mkdir(self.test_dir) + + for dt in self.dtypes: + with self.subTest(dtype=dt): + for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]): + with self.subTest(shape=shape): + save_file_mlx = os.path.join( + self.test_dir, f"mlx_{dt}_{i}_fs.npy" + ) + save_file_npy = os.path.join( + self.test_dir, f"npy_{dt}_{i}_fs.npy" + ) + + save_arr = np.random.uniform(0.0, 32.0, size=shape) + save_arr_npy = save_arr.astype(getattr(np, dt)) + save_arr_mlx = mx.array(save_arr_npy) + + with open(save_file_mlx, "wb") as f: + mx.save(f, save_arr_mlx) + + np.save(save_file_npy, save_arr_npy) + + # Load array saved by mlx as mlx array + with open(save_file_mlx, "rb") as f: + load_arr_mlx_mlx = mx.load(f) + self.assertTrue(mx.array_equal(load_arr_mlx_mlx, save_arr_mlx)) + + # Load array saved by numpy as mlx array + with open(save_file_npy, "rb") as f: + load_arr_npy_mlx = mx.load(f) + self.assertTrue(mx.array_equal(load_arr_npy_mlx, save_arr_mlx)) + + # Load array saved by mlx as numpy array + load_arr_mlx_npy = np.load(save_file_mlx) + self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy)) + + def test_savez_and_loadz(self): + if not os.path.isdir(self.test_dir): + os.mkdir(self.test_dir) + + for dt in self.dtypes: + with self.subTest(dtype=dt): + shapes = [(6,), (6, 6), (4, 1, 3, 1, 2)] + save_file_mlx_uncomp = os.path.join( + self.test_dir, f"mlx_{dt}_uncomp.npz" + ) + save_file_npy_uncomp = os.path.join( + self.test_dir, f"npy_{dt}_uncomp.npz" + ) + save_file_mlx_comp = os.path.join(self.test_dir, f"mlx_{dt}_comp.npz") + save_file_npy_comp = os.path.join(self.test_dir, f"npy_{dt}_comp.npz") + + # Make dictionary of multiple + save_arrs_npy = { + f"save_arr_{i}": np.random.uniform( + 0.0, 32.0, size=shapes[i] + ).astype(getattr(np, dt)) + for i in range(len(shapes)) + } + save_arrs_mlx = {k: mx.array(v) for k, v in save_arrs_npy.items()} + + # Save as npz files + np.savez(save_file_npy_uncomp, **save_arrs_npy) + mx.savez(save_file_mlx_uncomp, **save_arrs_mlx) + np.savez_compressed(save_file_npy_comp, **save_arrs_npy) + mx.savez_compressed(save_file_mlx_comp, **save_arrs_mlx) + + for save_file_npy, save_file_mlx in ( + (save_file_npy_uncomp, save_file_mlx_uncomp), + (save_file_npy_comp, save_file_mlx_comp), + ): + + # Load array saved by mlx as mlx array + load_arr_mlx_mlx = mx.load(save_file_mlx) + for k, v in load_arr_mlx_mlx.items(): + 
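+ # Loading an .npz gives back a dict keyed by the names used when saving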
self.assertTrue(mx.array_equal(save_arrs_mlx[k], v)) + + # Load arrays saved by numpy as mlx arrays + load_arr_npy_mlx = mx.load(save_file_npy) + for k, v in load_arr_npy_mlx.items(): + self.assertTrue(mx.array_equal(save_arrs_mlx[k], v)) + + # Load array saved by mlx as numpy array + load_arr_mlx_npy = np.load(save_file_mlx) + for k, v in load_arr_mlx_npy.items(): + self.assertTrue(np.array_equal(save_arrs_npy[k], v)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py new file mode 100644 index 000000000..cc664cbb0 --- /dev/null +++ b/python/tests/test_nn.py @@ -0,0 +1,231 @@ +import unittest + +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_flatten, tree_map, tree_unflatten +import numpy as np +import os +import tempfile + +import mlx_tests + + +class TestNN(mlx_tests.MLXTestCase): + def test_linear(self): + inputs = mx.zeros((10, 4)) + layer = nn.Linear(input_dims=4, output_dims=8) + outputs = layer(inputs) + self.assertEqual(tuple(outputs.shape), (10, 8)) + + def test_cross_entropy(self): + logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]]) + targets = mx.array([0, 1]) + losses = nn.losses.cross_entropy(logits, targets) + self.assertTrue(mx.array_equal(losses, mx.zeros((2,)))) + + def test_gelu(self): + inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414] + + # From: jax.nn.gelu(np.array(inputs), approximate=False) + expected = np.array( + [1.0093501, -0.16925684, 0.22918941, 0.60498625, 0.49459383] + ) + + out = nn.GELU()(mx.array(inputs)) + self.assertTrue(np.allclose(out, expected)) + + # Crudely check the approximations + x = mx.arange(-6.0, 6.0, 12 / 100) + y = nn.gelu(x) + y_hat1 = nn.gelu_approx(x) + y_hat2 = nn.gelu_fast_approx(x) + self.assertLess(mx.abs(y - y_hat1).max(), 0.0003) + self.assertLess(mx.abs(y - y_hat2).max(), 0.02) + + def test_group_norm(self): + x = mx.arange(100, dtype=mx.float32) + x = x.reshape(1, 10, 10, 1) + x = mx.broadcast_to(x, (2, 10, 10, 4)) + x = mx.concatenate([x, 0.5 * x], axis=-1) + + # Group norm in groups last mode + g = nn.GroupNorm(2, 8) + y = g(x) + means = y.reshape(2, -1, 2).mean(axis=1) + var = y.reshape(2, -1, 2).var(axis=1) + self.assertTrue(np.allclose(means, np.zeros_like(means), atol=1e-6)) + self.assertTrue(np.allclose(var, np.ones_like(var), atol=1e-6)) + g.weight = g.weight * 2 + g.bias = g.bias + 3 + y = g(x) + means = y.reshape(2, -1, 2).mean(axis=1) + var = y.reshape(2, -1, 2).var(axis=1) + self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6)) + self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6)) + + # Group norm in groups first mode + g = nn.GroupNorm(2, 8, pytorch_compatible=True) + y = g(x) + means = y.reshape(2, -1, 2, 4).mean(axis=(1, -1)) + var = y.reshape(2, -1, 2, 4).var(axis=(1, -1)) + self.assertTrue(np.allclose(means, np.zeros_like(means), atol=1e-6)) + self.assertTrue(np.allclose(var, np.ones_like(var), atol=1e-6)) + g.weight = g.weight * 2 + g.bias = g.bias + 3 + y = g(x) + means = y.reshape(2, -1, 2, 4).mean(axis=(1, -1)) + var = y.reshape(2, -1, 2, 4).var(axis=(1, -1)) + self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6)) + self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6)) + + def test_conv1d(self): + N = 5 + L = 12 + ks = 3 + C_in = 2 + C_out = 4 + x = mx.ones((N, L, C_in)) + c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks) + c.weight = mx.ones_like(c.weight) + y = c(x) + self.assertEqual(y.shape, [N, L - ks 
+ 1, C_out]) + self.assertTrue(mx.allclose(y, mx.full(y.shape, ks * C_in, mx.float32))) + + c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, stride=2) + y = c(x) + self.assertEqual(y.shape, [N, (L - ks + 1) // 2, C_out]) + self.assertTrue("bias" in c.parameters()) + + c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, bias=False) + self.assertTrue("bias" not in c.parameters()) + + def test_conv2d(self): + x = mx.ones((4, 8, 8, 3)) + c = nn.Conv2d(3, 1, 8) + y = c(x) + self.assertEqual(y.shape, [4, 1, 1, 1]) + c.weight = mx.ones_like(c.weight) / 8 / 8 / 3 + y = c(x) + self.assertTrue(np.allclose(y[:, 0, 0, 0], x.mean(axis=(1, 2, 3)))) + + # 3x3 conv no padding stride 1 + c = nn.Conv2d(3, 8, 3) + y = c(x) + self.assertEqual(y.shape, [4, 6, 6, 8]) + self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4) + + # 3x3 conv padding 1 stride 1 + c = nn.Conv2d(3, 8, 3, padding=1) + y = c(x) + self.assertEqual(y.shape, [4, 8, 8, 8]) + self.assertLess(mx.abs(y[:, 1:7, 1:7] - c.weight.sum((1, 2, 3))).max(), 1e-4) + self.assertLess( + mx.abs(y[:, 0, 0] - c.weight[:, 1:, 1:].sum(axis=(1, 2, 3))).max(), + 1e-4, + ) + self.assertLess( + mx.abs(y[:, 7, 7] - c.weight[:, :-1, :-1].sum(axis=(1, 2, 3))).max(), + 1e-4, + ) + self.assertLess( + mx.abs(y[:, 1:7, 7] - c.weight[:, :, :-1].sum(axis=(1, 2, 3))).max(), + 1e-4, + ) + self.assertLess( + mx.abs(y[:, 7, 1:7] - c.weight[:, :-1, :].sum(axis=(1, 2, 3))).max(), + 1e-4, + ) + + # 3x3 conv no padding stride 2 + c = nn.Conv2d(3, 8, 3, padding=0, stride=2) + y = c(x) + self.assertEqual(y.shape, [4, 3, 3, 8]) + self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4) + + def test_sequential(self): + x = mx.ones((10, 2)) + m = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 1)) + y = m(x) + self.assertEqual(y.shape, [10, 1]) + params = m.parameters() + self.assertTrue("layers" in params) + self.assertEqual(len(params["layers"]), 3) + self.assertTrue("weight" in params["layers"][0]) + self.assertEqual(len(params["layers"][1]), 0) + self.assertTrue("weight" in params["layers"][2]) + + m.layers[1] = nn.relu + y2 = m(x) + self.assertTrue(mx.array_equal(y, y2)) + + def test_module_utilities(self): + m = nn.Sequential( + nn.Sequential(nn.Linear(2, 10), nn.relu), + nn.Sequential(nn.Linear(10, 10), nn.ReLU()), + nn.Linear(10, 1), + mx.sigmoid, + ) + + children = m.children() + self.assertTrue(isinstance(children, dict)) + self.assertEqual(len(children), 1) + self.assertTrue(isinstance(children["layers"], list)) + self.assertEqual(len(children["layers"]), 4) + self.assertEqual(children["layers"][3], {}) + flat_children = tree_flatten(children, is_leaf=nn.Module.is_module) + self.assertEqual(len(flat_children), 3) + + leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module) + self.assertEqual(len(leaves), 4) + self.assertEqual(leaves[0][0], "layers.0.layers.0") + self.assertEqual(leaves[1][0], "layers.1.layers.0") + self.assertEqual(leaves[2][0], "layers.1.layers.1") + self.assertEqual(leaves[3][0], "layers.2") + self.assertTrue(leaves[0][1] is m.layers[0].layers[0]) + self.assertTrue(leaves[1][1] is m.layers[1].layers[0]) + self.assertTrue(leaves[2][1] is m.layers[1].layers[1]) + self.assertTrue(leaves[3][1] is m.layers[2]) + + m.eval() + + def assert_not_training(k, m): + self.assertFalse(m.training) + + m.apply_to_modules(assert_not_training) + + m.train() + + def assert_training(k, m): + self.assertTrue(m.training) + + m.apply_to_modules(assert_training) + + def test_sin_pe(self): + m = 
nn.SinusoidalPositionalEncoding(16, min_freq=0.01) + x = mx.arange(10) + y = m(x) + + self.assertEqual(y.shape, [10, 16]) + similarities = y @ y.T + self.assertLess( + mx.abs(similarities[mx.arange(10), mx.arange(10)] - 1).max(), 1e-5 + ) + + def test_io(self): + def make_model(): + return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2)) + + m = make_model() + tdir = tempfile.TemporaryDirectory() + file = os.path.join(tdir.name, "model.npz") + m.save_weights(file) + m_load = make_model() + m_load.load_weights(file) + tdir.cleanup() + + eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters()) + self.assertTrue(all(tree_flatten(eq_tree))) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py new file mode 100644 index 000000000..7cfc523f3 --- /dev/null +++ b/python/tests/test_optimizers.py @@ -0,0 +1,29 @@ +import unittest + +import mlx.core as mx +import mlx.optimizers as opt +import mlx.utils + +import mlx_tests + + +class TestOptimizers(mlx_tests.MLXTestCase): + def test_optimizers(self): + params = { + "first": [mx.zeros((10,)), mx.zeros((1,))], + "second": mx.zeros((1,)), + } + grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params) + + for optim in [opt.SGD(0.1), opt.Adam(0.1)]: + update = optim.apply_gradients(grads, params) + mx.eval(update) + equal_shape = mlx.utils.tree_map( + lambda x, y: x.shape == y.shape, params, update + ) + all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape)) + self.assertTrue(all_equal) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_random.py b/python/tests/test_random.py new file mode 100644 index 000000000..209d6d9e5 --- /dev/null +++ b/python/tests/test_random.py @@ -0,0 +1,192 @@ +import unittest + +import mlx.core as mx + +import mlx_tests + + +class TestRandom(mlx_tests.MLXTestCase): + def test_global_rng(self): + mx.random.seed(3) + a = mx.random.uniform() + b = mx.random.uniform() + + mx.random.seed(3) + x = mx.random.uniform() + y = mx.random.uniform() + + self.assertEqual(a.item(), x.item()) + self.assertEqual(y.item(), b.item()) + + def test_key(self): + k1 = mx.random.key(0) + k2 = mx.random.key(0) + self.assertTrue(mx.array_equal(k1, k2)) + + k2 = mx.random.key(1) + self.assertFalse(mx.array_equal(k1, k2)) + + def test_key_split(self): + key = mx.random.key(0) + + k1, k2 = mx.random.split(key) + self.assertFalse(mx.array_equal(k1, k2)) + + r1, r2 = mx.random.split(key) + self.assertTrue(mx.array_equal(k1, r1)) + self.assertTrue(mx.array_equal(k2, r2)) + + keys = mx.random.split(key, 10) + self.assertEqual(keys.shape, [10, 2]) + + def test_uniform(self): + key = mx.random.key(0) + a = mx.random.uniform(key=key) + self.assertEqual(a.shape, []) + self.assertEqual(a.dtype, mx.float32) + + b = mx.random.uniform(key=key) + self.assertEqual(a.item(), b.item()) + + a = mx.random.uniform(shape=(2, 3)) + self.assertEqual(a.shape, [2, 3]) + + a = mx.random.uniform(shape=(1000,), low=-1, high=5) + self.assertTrue(mx.all((a > -1) < 5).item()) + + a = mx.random.uniform(shape=(1000,), low=mx.array(-1), high=5) + self.assertTrue(mx.all((a > -1) < 5).item()) + + def test_normal(self): + key = mx.random.key(0) + a = mx.random.normal(key=key) + self.assertEqual(a.shape, []) + self.assertEqual(a.dtype, mx.float32) + + b = mx.random.normal(key=key) + self.assertEqual(a.item(), b.item()) + + a = mx.random.normal(shape=(2, 3)) + self.assertEqual(a.shape, [2, 3]) + + ## Generate in float16 or bfloat16 + for t in 
[mx.float16, mx.bfloat16]: + a = mx.random.normal(dtype=t) + self.assertEqual(a.dtype, t) + + def test_randint(self): + a = mx.random.randint(0, 1, []) + self.assertEqual(a.shape, []) + self.assertEqual(a.dtype, mx.int32) + + shape = [88] + low = mx.array(3) + high = mx.array(15) + + key = mx.random.key(0) + a = mx.random.randint(low, high, shape, key=key) + self.assertEqual(a.shape, shape) + self.assertEqual(a.dtype, mx.int32) + + # Check using the same key yields the same value + b = mx.random.randint(low, high, shape, key=key) + self.assertListEqual(a.tolist(), b.tolist()) + + shape = [3, 4] + low = mx.reshape(mx.array([0] * 3), [3, 1]) + high = mx.reshape(mx.array([12, 13, 14, 15]), [1, 4]) + + a = mx.random.randint(low, high, shape) + self.assertEqual(a.shape, shape) + + a = mx.random.randint(-10, 10, [1000, 1000]) + self.assertTrue(mx.all(-10 <= a).item() and mx.all(a < 10).item()) + + a = mx.random.randint(10, -10, [1000, 1000]) + self.assertTrue(mx.all(a == 10).item()) + + def test_bernoulli(self): + a = mx.random.bernoulli() + self.assertEqual(a.shape, []) + self.assertEqual(a.dtype, mx.bool_) + + a = mx.random.bernoulli(mx.array(0.5), [5]) + self.assertEqual(a.shape, [5]) + + a = mx.random.bernoulli(mx.array([2.0, -2.0])) + self.assertEqual(a.tolist(), [True, False]) + self.assertEqual(a.shape, [2]) + + p = mx.array([0.1, 0.2, 0.3]) + mx.reshape(p, [1, 3]) + x = mx.random.bernoulli(p, [4, 3]) + self.assertEqual(x.shape, [4, 3]) + + with self.assertRaises(ValueError): + mx.random.bernoulli(p, [2]) # Bad shape + + with self.assertRaises(ValueError): + mx.random.bernoulli(0, [2]) # Bad type + + def test_truncated_normal(self): + a = mx.random.truncated_normal(-2.0, 2.0) + self.assertEqual(a.size, 1) + self.assertEqual(a.dtype, mx.float32) + + a = mx.random.truncated_normal(mx.array([]), mx.array([])) + self.assertEqual(a.dtype, mx.float32) + self.assertEqual(a.size, 0) + + lower = mx.reshape(mx.array([-2.0, 0.0]), [1, 2]) + upper = mx.reshape(mx.array([0.0, 1.0, 2.0]), [3, 1]) + a = mx.random.truncated_normal(lower, upper) + + self.assertEqual(a.shape, [3, 2]) + self.assertTrue(mx.all(lower <= a).item() and mx.all(a <= upper).item()) + + a = mx.random.truncated_normal(2.0, -2.0) + self.assertTrue(mx.all(a == 2.0).item()) + + a = mx.random.truncated_normal(-3.0, 3.0, [542, 399]) + self.assertEqual(a.shape, [542, 399]) + + lower = mx.array([-2.0, -1.0]) + higher = mx.array([1.0, 2.0, 3.0]) + with self.assertRaises(ValueError): + mx.random.truncated_normal(lower, higher) # Bad shape + + def test_gumbel(self): + samples = mx.random.gumbel(shape=(100, 100)) + self.assertEqual(samples.shape, [100, 100]) + self.assertEqual(samples.dtype, mx.float32) + mean = 0.5772 + # Std deviation of the sample mean is small (<0.02), + # so this test is pretty conservative + self.assertTrue(mx.abs(mx.mean(samples) - mean) < 0.2) + + def test_categorical(self): + logits = mx.zeros((10, 20)) + self.assertEqual(mx.random.categorical(logits, -1).shape, [10]) + self.assertEqual(mx.random.categorical(logits, 0).shape, [20]) + self.assertEqual(mx.random.categorical(logits, 1).shape, [10]) + + out = mx.random.categorical(logits) + self.assertEqual(out.shape, [10]) + self.assertEqual(out.dtype, mx.uint32) + self.assertTrue(mx.max(out).item() < 20) + + out = mx.random.categorical(logits, 0, [5, 20]) + self.assertEqual(out.shape, [5, 20]) + self.assertTrue(mx.max(out).item() < 10) + + out = mx.random.categorical(logits, 1, num_samples=7) + self.assertEqual(out.shape, [10, 7]) + out = mx.random.categorical(logits, 
0, num_samples=7) + self.assertEqual(out.shape, [20, 7]) + + with self.assertRaises(ValueError): + mx.random.categorical(logits, shape=[10, 5], num_samples=5) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_tree.py b/python/tests/test_tree.py new file mode 100644 index 000000000..fe6d8cdb7 --- /dev/null +++ b/python/tests/test_tree.py @@ -0,0 +1,26 @@ +import unittest + +import mlx.core as mx +import mlx.utils + +import mlx_tests + + +class TestTreeUtils(mlx_tests.MLXTestCase): + def test_tree_map(self): + tree = {"a": 0, "b": 1, "c": 2} + tree = mlx.utils.tree_map(lambda x: x + 1, tree) + + expected_tree = {"a": 1, "b": 2, "c": 3} + self.assertEqual(tree, expected_tree) + + def test_tree_flatten(self): + tree = [{"a": 1, "b": 2}, "c"] + vals = (1, 2, "c") + flat_tree = mlx.utils.tree_flatten(tree) + self.assertEqual(list(zip(*flat_tree))[1], vals) + self.assertEqual(mlx.utils.tree_unflatten(flat_tree), tree) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py new file mode 100644 index 000000000..151fbc3ca --- /dev/null +++ b/python/tests/test_vmap.py @@ -0,0 +1,167 @@ +import unittest + +import mlx.core as mx + +import mlx_tests + + +class TestVmap(mlx_tests.MLXTestCase): + def test_basics(self): + # Can't vmap over scalars + with self.assertRaises(ValueError): + mx.vmap(mx.exp)(mx.array(1.0)) + + # Invalid input + with self.assertRaises(ValueError): + mx.vmap(mx.exp)("hello") + + # Invalid axes + with self.assertRaises(ValueError): + mx.vmap(mx.exp, in_axes="hello")(mx.array([0, 1])) + + with self.assertRaises(ValueError): + mx.vmap(mx.exp, in_axes=2)(mx.array([0, 1])) + + with self.assertRaises(ValueError): + mx.vmap(mx.exp, out_axes="hello")(mx.array([0, 1])) + + with self.assertRaises(ValueError): + mx.vmap(mx.exp, out_axes=2)(mx.array([0, 1])) + + def test_unary(self): + ops = [ + "abs", + "cos", + "erf", + "erfinv", + "exp", + "log", + "log1p", + "log2", + "log10", + "logical_not", + "negative", + "reciprocal", + "rsqrt", + "sigmoid", + "sign", + "sin", + "sqrt", + "square", + ] + ops = ["erfinv"] + for opname in ops: + with self.subTest(op=opname): + op = getattr(mx, opname) + x = mx.arange(5) + y = mx.vmap(op)(x) + self.assertTrue(mx.array_equal(y, op(x), equal_nan=True)) + + x = mx.arange(8).reshape(2, 4) + y = mx.vmap(op)(x) + self.assertTrue(mx.array_equal(y, op(x), equal_nan=True)) + + y = mx.vmap(op, in_axes=1, out_axes=1)(x) + self.assertTrue(mx.array_equal(y, op(x), equal_nan=True)) + + def test_binary(self): + ops = [ + "add", + "divide", + "equal", + "greater", + "greater_equal", + "less", + "less_equal", + "logaddexp", + "maximum", + "minimum", + "multiply", + "power", + "subtract", + ] + for opname in ops: + with self.subTest(op=opname): + op = getattr(mx, opname) + x = mx.random.uniform(shape=(5,)) + y = mx.random.uniform(shape=(5,)) + out = mx.vmap(op)(x, y) + self.assertTrue(mx.array_equal(out, op(x, y))) + + x = mx.random.uniform(shape=(2, 4)) + y = mx.random.uniform(shape=(2, 4)) + out = mx.vmap(op)(x, y) + self.assertTrue(mx.array_equal(out, op(x, y))) + + out = mx.vmap(op, in_axes=(0, 0), out_axes=0)(x, y) + self.assertTrue(mx.array_equal(out, op(x, y))) + + y = mx.random.uniform(shape=(4, 2)) + out = mx.vmap(op, in_axes=(0, 1), out_axes=0)(x, y) + self.assertTrue(mx.array_equal(out, op(x, y.T))) + + out = mx.vmap(op, in_axes=(0, 1), out_axes=1)(x, y) + self.assertTrue(mx.array_equal(out, op(x, y.T).T)) + + def test_tree(self): + def my_fun(tree): + return 
(tree["a"] + tree["b"][0]) * tree["b"][1] + + tree = { + "a": mx.random.uniform(shape=(2, 4)), + "b": ( + mx.random.uniform(shape=(2, 4)), + mx.random.uniform(shape=(2, 4)), + ), + } + out = mx.vmap(my_fun)(tree) + expected = my_fun(tree) + self.assertTrue(mx.array_equal(out, my_fun(tree))) + + with self.assertRaises(ValueError): + mx.vmap(my_fun, in_axes={"a": 0, "b": 0}, out_axes=0)(tree) + + with self.assertRaises(ValueError): + mx.vmap(my_fun, in_axes={"a": 0, "b": ((0, 0), 0)}, out_axes=0)(tree) + + out = mx.vmap(my_fun, in_axes=({"a": 0, "b": 0},), out_axes=0)(tree) + self.assertTrue(mx.array_equal(out, my_fun(tree))) + + out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (0, 0)},), out_axes=0)(tree) + self.assertTrue(mx.array_equal(out, my_fun(tree))) + + tree = { + "a": mx.random.uniform(shape=(2, 4)), + "b": ( + mx.random.uniform(shape=(4, 2)), + mx.random.uniform(shape=(4, 2)), + ), + } + out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (1, 1)},), out_axes=0)(tree) + expected = (tree["a"] + tree["b"][0].T) * tree["b"][1].T + self.assertTrue(mx.array_equal(out, expected)) + + def my_fun(x, y): + return {"a": x + y, "b": x * y} + + x = mx.random.uniform(shape=(2, 4)) + y = mx.random.uniform(shape=(2, 4)) + out = mx.vmap(my_fun, in_axes=0, out_axes=0)(x, y) + expected = my_fun(x, y) + self.assertTrue(mx.array_equal(out["a"], expected["a"])) + self.assertTrue(mx.array_equal(out["b"], expected["b"])) + + with self.assertRaises(ValueError): + mx.vmap(my_fun, in_axes=0, out_axes=(0, 1))(x, y) + + with self.assertRaises(ValueError): + mx.vmap(my_fun, in_axes=0, out_axes={"a": 0, "c": 1})(x, y) + + out = mx.vmap(my_fun, in_axes=0, out_axes={"a": 1, "b": 0})(x, y) + expected = my_fun(x, y) + self.assertTrue(mx.array_equal(out["a"].T, expected["a"])) + self.assertTrue(mx.array_equal(out["b"], expected["b"])) + + +if __name__ == "__main__": + unittest.main() diff --git a/setup.py b/setup.py new file mode 100644 index 000000000..c4d3569cb --- /dev/null +++ b/setup.py @@ -0,0 +1,127 @@ +import os +import re +import subprocess +import sys +import sysconfig +from pathlib import Path + +from setuptools import Extension, setup, find_namespace_packages +from setuptools.command.build_ext import build_ext + + +# A CMakeExtension needs a sourcedir instead of a file list. +# The name must be the _single_ output extension from the CMake build. +# If you need multiple extensions, see scikit-build. +class CMakeExtension(Extension): + def __init__(self, name: str, sourcedir: str = "") -> None: + super().__init__(name, sources=[]) + self.sourcedir = os.fspath(Path(sourcedir).resolve()) + + +class CMakeBuild(build_ext): + def build_extension(self, ext: CMakeExtension) -> None: + # Must be in this form due to bug in .resolve() only fixed in Python 3.10+ + ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name) # type: ignore[no-untyped-call] + extdir = ext_fullpath.parent.resolve() + + debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug + cfg = "Debug" if debug else "Release" + + # CMake lets you override the generator - we need to check this. + # Can be set with Conda-Build, for example. + cmake_generator = os.environ.get("CMAKE_GENERATOR", "") + + # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON + # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code + # from Python. 
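+ # (The analogous value in this project is MLX_VERSION, appended to
+ # cmake_args further below.)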
+ cmake_args = [ + f"-DCMAKE_INSTALL_PREFIX={extdir}{os.sep}", + f"-DCMAKE_BUILD_TYPE={cfg}", + "-DBUILD_SHARED_LIBS=ON", + "-DMLX_BUILD_PYTHON_BINDINGS=ON", + "-DMLX_BUILD_TESTS=OFF", + "-DMLX_BUILD_BENCHMARKS=OFF", + "-DMLX_BUILD_EXAMPLES=OFF", + f"-DMLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY={extdir}{os.sep}", + ] + build_args = [] + # Adding CMake arguments set as environment variable + # (needed e.g. to build for ARM OSx on conda-forge) + if "CMAKE_ARGS" in os.environ: + cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item] + + # Pass version to C++ + cmake_args += [f"-DMLX_VERSION={self.distribution.get_version()}"] # type: ignore[attr-defined] + + if sys.platform.startswith("darwin"): + # Cross-compile support for macOS - respect ARCHFLAGS if set + archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", "")) + if archs: + cmake_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))] + + # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level + # across all generators. + if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: + # self.parallel is a Python 3 only way to set parallel jobs by hand + # using -j in the build_ext call, not supported by pip or PyPA-build. + if hasattr(self, "parallel") and self.parallel: + # CMake 3.12+ only. + build_args += [f"-j{self.parallel}"] + + build_temp = Path(self.build_temp) / ext.name + if not build_temp.exists(): + build_temp.mkdir(parents=True) + + subprocess.run( + ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True + ) + subprocess.run( + ["cmake", "--build", ".", "--target", "install", *build_args], + cwd=build_temp, + check=True, + ) + + # Make sure to copy mlx.metallib for inplace builds + def run(self): + super().run() + + # Based on https://github.com/pypa/setuptools/blob/main/setuptools/command/build_ext.py#L102 + if self.inplace: + for ext in self.extensions: + if ext.name == "mlx.core": + # Resolve inplace package dir + build_py = self.get_finalized_command("build_py") + inplace_file, regular_file = self._get_inplace_equivalent( + build_py, ext + ) + + inplace_dir = str(Path(inplace_file).parent.resolve()) + regular_dir = str(Path(regular_file).parent.resolve()) + + self.copy_tree(regular_dir, inplace_dir) + + +# The information here can also be placed in setup.cfg - better separation of +# logic and declaration, and simpler if you include description/version in a file. 
+if __name__ == "__main__": + packages = find_namespace_packages( + where="python", exclude=["src", "tests", "tests.*"] + ) + package_dir = {"": "python"} + package_data = {"mlx": ["lib/*", "include/*", "share/*"]} + setup( + name="mlx", + version="0.0.2", + author="MLX Contributors", + author_email="mlx@group.apple.com", + description="A framework for machine learning on Apple Silicon.", + long_description="", + packages=packages, + package_dir=package_dir, + package_data=package_data, + include_package_data=True, + ext_modules=[CMakeExtension("mlx.core")], + cmdclass={"build_ext": CMakeBuild}, + zip_safe=False, + python_requires=">=3.7", + ) diff --git a/tests/allocator_tests.cpp b/tests/allocator_tests.cpp new file mode 100644 index 000000000..70199eb46 --- /dev/null +++ b/tests/allocator_tests.cpp @@ -0,0 +1,41 @@ +#include + +#include "doctest/doctest.h" + +#include "mlx/allocator.h" + +using namespace mlx::core; + +TEST_CASE("test simple allocations") { + { + auto buffer = allocator::malloc(sizeof(float)); + auto fptr = static_cast(buffer.raw_ptr()); + *fptr = 0.5f; + CHECK_EQ(*fptr, 0.5f); + allocator::free(buffer); + } + + { + auto buffer = allocator::malloc(128 * sizeof(int)); + int* ptr = static_cast(buffer.raw_ptr()); + for (int i = 0; i < 128; ++i) { + ptr[i] = i; + } + allocator::free(buffer); + } + + { + auto buffer = allocator::malloc(0); + allocator::free(buffer); + } +} + +TEST_CASE("test large allocations") { + size_t size = 1 << 30; + for (int i = 0; i < 100; ++i) { + auto buffer = allocator::malloc(size); + allocator::free(buffer); + } + // Shouldn't be able to allocate an exabyte anytime soon. + CHECK_THROWS_AS(allocator::malloc(1ull << 60), std::runtime_error); +} diff --git a/tests/array_tests.cpp b/tests/array_tests.cpp new file mode 100644 index 000000000..c435b1db4 --- /dev/null +++ b/tests/array_tests.cpp @@ -0,0 +1,589 @@ +#include + +#include "doctest/doctest.h" + +#include "mlx/mlx.h" + +using namespace mlx::core; + +TEST_CASE("test array basics") { + // Scalar + array x(1.0); + CHECK_EQ(x.size(), 1); + CHECK_EQ(x.ndim(), 0); + CHECK_EQ(x.shape(), std::vector{}); + CHECK_THROWS_AS(x.shape(0), std::out_of_range); + CHECK_THROWS_AS(x.shape(-1), std::out_of_range); + CHECK_EQ(x.strides(), std::vector{}); + CHECK_EQ(x.itemsize(), sizeof(float)); + CHECK_EQ(x.nbytes(), sizeof(float)); + CHECK_EQ(x.dtype(), float32); + CHECK_EQ(x.item(), 1.0); + + // Scalar with specified type + x = array(1, float32); + CHECK_EQ(x.dtype(), float32); + CHECK_EQ(x.item(), 1.0); + + // Scalar with specified type + x = array(1, bool_); + CHECK_EQ(x.dtype(), bool_); + CHECK_EQ(x.itemsize(), sizeof(bool)); + CHECK_EQ(x.nbytes(), sizeof(bool)); + CHECK_EQ(x.item(), true); + + // Check shaped arrays + x = array({1.0}); + CHECK_EQ(x.dtype(), float32); + CHECK_EQ(x.size(), 1); + CHECK_EQ(x.ndim(), 1); + CHECK_EQ(x.shape(), std::vector{1}); + CHECK_EQ(x.shape(0), 1); + CHECK_EQ(x.shape(-1), 1); + CHECK_THROWS_AS(x.shape(1), std::out_of_range); + CHECK_THROWS_AS(x.shape(-2), std::out_of_range); + CHECK_EQ(x.strides(), std::vector{1}); + CHECK_EQ(x.item(), 1.0); + + // Check empty array + x = array({}); + CHECK_EQ(x.size(), 0); + CHECK_EQ(x.dtype(), float32); + CHECK_EQ(x.itemsize(), sizeof(float)); + CHECK_EQ(x.nbytes(), 0); + CHECK_THROWS_AS(x.item(), std::invalid_argument); + + x = array({1.0, 1.0}); + CHECK_EQ(x.size(), 2); + CHECK_EQ(x.shape(), std::vector{2}); + CHECK_EQ(x.itemsize(), sizeof(float)); + CHECK_EQ(x.nbytes(), x.itemsize() * x.size()); + + // Accessing item in non-scalar array 
throws + CHECK_THROWS_AS(x.item(), std::invalid_argument); + + x = array({1.0, 1.0, 1.0}, {1, 3}); + CHECK(x.size() == 3); + CHECK(x.shape() == std::vector{1, 3}); + CHECK(x.strides() == std::vector{3, 1}); + + // Test wrong size/shapes throw: + CHECK_THROWS_AS(array({1.0, 1.0, 1.0}, {4}), std::invalid_argument); + CHECK_THROWS_AS(array({1.0, 1.0, 1.0}, {1, 4}), std::invalid_argument); + CHECK_THROWS_AS(array({1.0, 1.0, 1.0}, {1, 2}), std::invalid_argument); + + // Test array ids work as expected + x = array(1.0); + auto y = x; + CHECK_EQ(y.id(), x.id()); + array z(2.0); + CHECK_NE(z.id(), x.id()); + z = x; + CHECK_EQ(z.id(), x.id()); + + // Array creation from pointer + float data[] = {0.0, 1.0, 2.0, 3.0}; + x = array(data, {4}); + CHECK_EQ(x.dtype(), float32); + CHECK(array_equal(x, array({0.0, 1.0, 2.0, 3.0})).item()); + + // Array creation from vectors + { + std::vector data = {0, 1, 2, 3}; + x = array(data.begin(), {4}); + CHECK_EQ(x.dtype(), int32); + CHECK(array_equal(x, array({0, 1, 2, 3})).item()); + } + + { + std::vector data = {false, true, false, true}; + x = array(data.begin(), {4}); + CHECK_EQ(x.dtype(), bool_); + CHECK(array_equal(x, array({false, true, false, true})).item()); + } +} + +TEST_CASE("test array types") { +#define basic_dtype_test(T, mlx_type) \ + T val = 42; \ + array x(val); \ + CHECK_EQ(x.dtype(), mlx_type); \ + CHECK_EQ(x.item(), val); \ + x = array({val, val}); \ + CHECK_EQ(x.dtype(), mlx_type); + + // bool_ + { + array x(true); + CHECK_EQ(x.dtype(), bool_); + CHECK_EQ(x.item(), true); + + x = array({true, false}); + CHECK_EQ(x.dtype(), bool_); + + x = array({true, false}, float32); + CHECK_EQ(x.dtype(), float32); + CHECK(array_equal(x, array({1.0f, 0.0f})).item()); + } + + // uint8 + { basic_dtype_test(uint8_t, uint8); } + + // uint16 + { basic_dtype_test(uint16_t, uint16); } + + // uint32 + { basic_dtype_test(uint32_t, uint32); } + + // uint64 + { basic_dtype_test(uint64_t, uint64); } + + // int8 + { basic_dtype_test(int8_t, int8); } + + // int16 + { basic_dtype_test(int16_t, int16); } + + // int32 + { basic_dtype_test(int32_t, int32); } + + // int64 + { basic_dtype_test(int64_t, int64); } + + // float16 + { basic_dtype_test(float16_t, float16); } + + // float32 + { basic_dtype_test(float, float32); } + + // bfloat16 + { basic_dtype_test(bfloat16_t, bfloat16); } + + // uint32 + { + uint32_t val = UINT_MAX; + array x(val); + CHECK_EQ(x.dtype(), uint32); + CHECK_EQ(x.item(), val); + + x = array({1u, 2u}); + CHECK_EQ(x.dtype(), uint32); + } + + // int32 + { + array x(-1); + CHECK_EQ(x.dtype(), int32); + CHECK_EQ(x.item(), -1); + + x = array({-1, 2}); + CHECK_EQ(x.dtype(), int32); + + std::vector data{0, 1, 2}; + x = array(data.data(), {static_cast(data.size())}, bool_); + CHECK_EQ(x.dtype(), bool_); + CHECK(array_equal(x, array({false, true, true})).item()); + } + + // int64 + { + int64_t val = static_cast(INT_MIN) - 1; + array x(val); + CHECK_EQ(x.dtype(), int64); + CHECK_EQ(x.item(), val); + + x = array({val, val}); + CHECK_EQ(x.dtype(), int64); + } + + // float32 + { + array x(3.14f); + CHECK_EQ(x.dtype(), float32); + CHECK_EQ(x.item(), 3.14f); + + x = array(1.25); + CHECK_EQ(x.dtype(), float32); + CHECK_EQ(x.item(), 1.25f); + + x = array({1.0f, 2.0f}); + CHECK_EQ(x.dtype(), float32); + + x = array({1.0, 2.0}); + CHECK_EQ(x.dtype(), float32); + + std::vector data{1.0, 2.0, 4.0}; + x = array(data.data(), {static_cast(data.size())}); + CHECK_EQ(x.dtype(), float32); + CHECK(array_equal(x, array({1.0f, 2.0f, 4.0f})).item()); + } + + // complex64 + { + 
complex64_t v = {1.0f, 1.0f}; + array x(v); + CHECK_EQ(x.dtype(), complex64); + CHECK_EQ(x.item(), v); + + array y(std::complex{1.0f, 1.0f}); + CHECK_EQ(x.dtype(), complex64); + CHECK_EQ(x.item(), v); + } + +#undef basic_dtype_test + +#define basic_dtype_str_test(s, dtype) \ + CHECK_EQ(s, dtype_to_array_protocol(dtype)); \ + CHECK_EQ(dtype_from_array_protocol(s), dtype); + + // To and from str + { + basic_dtype_str_test("|b1", bool_); + basic_dtype_str_test("|u1", uint8); + basic_dtype_str_test("{1, 1}); + CHECK_EQ(y.data_size(), 1); + CHECK_EQ(y.flags().contiguous, true); + CHECK_EQ(y.flags().row_contiguous, true); + CHECK_EQ(y.flags().col_contiguous, true); + + x = array({0.0f, 1.0f, 2.0f, 3.0f}, {1, 4}); + y = slice(x, {0, 0}, {1, 4}, {1, 2}); + eval(y); + CHECK_EQ(y.shape(), std::vector{1, 2}); + CHECK_EQ(y.flags().contiguous, false); + CHECK_EQ(y.flags().row_contiguous, false); + CHECK_EQ(y.flags().col_contiguous, false); + + x = broadcast_to(array(1.0f), {4, 10}); + y = slice(x, {0, 0}, {4, 10}, {2, 2}); + eval(y); + CHECK_EQ(y.shape(), std::vector{2, 5}); + CHECK_EQ(y.data_size(), 1); + CHECK_EQ(y.flags().contiguous, true); + CHECK_EQ(y.flags().row_contiguous, false); + CHECK_EQ(y.flags().col_contiguous, false); + + x = broadcast_to(array({1.0f, 2.0f}), {4, 2}); + y = slice(x, {0, 0}, {1, 2}, {1, 1}); + eval(y); + CHECK_EQ(y.data_size(), 2); + CHECK_EQ(y.flags().contiguous, true); + CHECK_EQ(y.flags().row_contiguous, true); + CHECK_EQ(y.flags().col_contiguous, true); + + y = slice(x, {1, 0}, {2, 2}, {1, 1}); + eval(y); + CHECK_EQ(y.data_size(), 2); + CHECK_EQ(y.flags().contiguous, true); + CHECK_EQ(y.flags().row_contiguous, true); + CHECK_EQ(y.flags().col_contiguous, true); + + x = array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}); + y = slice(x, {0, 0}, {2, 2}, {1, 1}); + eval(y); + CHECK_EQ(y.data_size(), 4); + CHECK_EQ(y.flags().contiguous, true); + CHECK_EQ(y.flags().row_contiguous, true); + CHECK_EQ(y.flags().col_contiguous, false); + + y = slice(transpose(x), {0, 0}, {2, 2}, {1, 1}); + eval(y); + CHECK_EQ(y.data_size(), 4); + CHECK_EQ(y.flags().contiguous, true); + CHECK_EQ(y.flags().row_contiguous, false); + CHECK_EQ(y.flags().col_contiguous, true); + + x = ones({2, 4}); + auto out = split(x, 2); + eval(out); + for (auto y : out) { + CHECK_EQ(y.data_size(), 4); + CHECK_EQ(y.flags().contiguous, true); + CHECK_EQ(y.flags().row_contiguous, true); + CHECK_EQ(y.flags().col_contiguous, true); + } + out = split(x, 4, 1); + eval(out); + for (auto y : out) { + CHECK_EQ(y.flags().contiguous, false); + CHECK_EQ(y.flags().row_contiguous, false); + CHECK_EQ(y.flags().col_contiguous, false); + } +} + +TEST_CASE("test array iteration") { + // Dim 0 arrays + auto arr = array(1); + CHECK_THROWS(arr.begin()); + + // Iterated arrays are read only + CHECK(std::is_const_v); + + arr = array({1, 2, 3, 4, 5}); + int i = 0; + for (auto a : arr) { + i++; + CHECK_EQ(a.item(), i); + } + CHECK_EQ(i, 5); + + arr = array({1, 2, 3, 4}, {2, 2}); + CHECK(array_equal(*arr.begin(), array({1, 2})).item()); + CHECK(array_equal(*(arr.begin() + 1), array({3, 4})).item()); + CHECK_EQ(arr.begin() + 2, arr.end()); +} + +TEST_CASE("test array shared buffer") { + std::vector shape = {2, 2}; + int n_elem = shape[0] * shape[1]; + + allocator::Buffer buf_b = allocator::malloc(n_elem * sizeof(float)); + void* buf_b_ptr = buf_b.raw_ptr(); + float* float_buf_b = (float*)buf_b_ptr; + + for (int i = 0; i < n_elem; i++) { + float_buf_b[i] = 2.; + } + + CHECK_EQ(float_buf_b[0], ((float*)buf_b_ptr)[0]); + + auto deleter = 
[float_buf_b](allocator::Buffer buf) { + CHECK_EQ(float_buf_b, (float*)buf.raw_ptr()); + CHECK_EQ(float_buf_b[0], ((float*)buf.raw_ptr())[0]); + allocator::free(buf); + }; + + array a = ones(shape, float32); + array b = array(buf_b, shape, float32, deleter); + + eval(a + b); +} diff --git a/tests/autograd_tests.cpp b/tests/autograd_tests.cpp new file mode 100644 index 000000000..7c7245cd7 --- /dev/null +++ b/tests/autograd_tests.cpp @@ -0,0 +1,1192 @@ +#include +#include +#include +#include +#include +#include "doctest/doctest.h" + +#include "mlx/graph_utils.h" +#include "mlx/mlx.h" + +using namespace mlx::core; + +TEST_CASE("test stop gradient") { + auto x = zeros({5, 5}); + auto y = stop_gradient(x); + CHECK(array_equal(y, zeros({5, 5})).item()); + + x = zeros({5, 5}, int32); + y = stop_gradient(x); + CHECK_EQ(y.dtype(), int32); + CHECK(array_equal(y, zeros({5, 5}, int32)).item()); + + { + auto fun = [](array input) { return stop_gradient(add(input, ones({2}))); }; + auto vfun = vmap(fun); + auto out = vfun(ones({3, 2})); + CHECK(array_equal(out, full({3, 2}, 2.0)).item()); + } + + { + auto fun = [](array input) { return add(stop_gradient(input), ones({2})); }; + auto vfun = vmap(fun); + auto out = vfun(ones({3, 2})); + CHECK(array_equal(out, full({3, 2}, 2.0)).item()); + } + + { + auto x = array(1.); + auto fun = [](array in) { return stop_gradient(add(in, in)); }; + auto out = vjp(fun, x, array(1.)).second; + CHECK(array_equal(out, array(0.)).item()); + + out = jvp(fun, x, array(1.)).second; + CHECK(array_equal(out, array(0.)).item()); + } + + { + auto x = array(1.); + auto fun = [](array in) { return add(in, stop_gradient(in)); }; + auto out = vjp(fun, x, array(1.)).second; + CHECK(array_equal(out, array(1.)).item()); + + out = jvp(fun, x, array(1.)).second; + CHECK(array_equal(out, array(1.)).item()); + } + + { + auto x = array(1.); + auto fun = [](array in) { + for (int i = 0; i < 10; ++i) { + in = add(in, in); + } + return stop_gradient(in); + }; + { + auto out = vjp(fun, x, array(1.)).second; + std::ostringstream g_ss; + print_graph(g_ss, out); + auto g_str = g_ss.str(); + auto count = std::count(g_str.begin(), g_str.end(), '\n'); + CHECK(count < 5); + } + { + auto out = jvp(fun, x, array(1.)).second; + std::ostringstream g_ss; + print_graph(g_ss, out); + auto g_str = g_ss.str(); + auto count = std::count(g_str.begin(), g_str.end(), '\n'); + CHECK(count < 5); + } + } +} + +TEST_CASE("test jvp") { + { + auto fun = [](const std::vector& inputs) { + return std::vector{add(inputs[0], inputs[1])}; + }; + auto x = array(1.0f); + auto y = array(1.0f); + auto [out, dout] = jvp(fun, {x, y}, {array(1.0f), array(3.0f)}); + CHECK_EQ(out[0].item(), 2.0f); + CHECK_EQ(dout[0].item(), 4.0f); + } + + // Evaling in function without graph retention throws + { + auto fun = [](const array& x) { + auto y = 3 * x; + eval(y); + return 2 * y; + }; + CHECK_THROWS(jvp(fun, array(1.0f), array(1.0f))); + + // Ok with graph retention + auto fun1 = [](const array& x) { + auto y = 3 * x; + eval({y}, true); + return 2 * y; + }; + CHECK_EQ(jvp(fun1, array(1.0f), array(1.0f)).second.item(), 6.0f); + } + + // Only one argument + { + auto x = array(1.0f); + auto fun = [x](array in) { return add(x, in); }; + auto y = array(1.0f); + auto out = jvp(fun, y, array(3.0f)).second; + CHECK_EQ(out.item(), 3.0f); + } + + // Input also in capture clause + { + auto x = array(1.0f); + auto fun = [x](array in) { return in + x; }; + auto out = jvp(fun, x, array(1.0f)).second; + CHECK_EQ(out.item(), 1.0f); + } + + // Throws on 
incorrectly shaped inputs + { + auto fun = [](array in) { return add(in, in); }; + CHECK_THROWS_AS(jvp(fun, array(1), array({1, 1})), std::invalid_argument); + } + + // Throws on wrong number of inputs + { + auto fun = [](std::vector inputs) { + return std::vector{inputs[0], inputs[1]}; + }; + CHECK_THROWS_AS( + jvp(fun, {array(1), array(1)}, {array(1)}), std::invalid_argument); + } + + // No dependence between input and output + { + auto fun = [](array in) { return array({1.0, 1.0}); }; + auto out = jvp(fun, array(1.0f), array(1.0f)).second; + CHECK(array_equal(out, zeros({2})).item()); + } +} + +TEST_CASE("test vjp") { + { + auto x = array(1.0f); + auto y = array(1.0f); + auto fun = [y](array in) { return add(in, y); }; + auto [out, dout] = vjp(fun, x, array(1.0f)); + CHECK_EQ(out.item(), 2.0f); + CHECK_EQ(dout.item(), 1.0f); + } + + { + auto x = array(1.0f); + auto fun = [](array in) { return in + in + in; }; + auto out = vjp(fun, x, array(1.0f)).second; + CHECK_EQ(out.item(), 3.0f); + out = vjp(fun, x, array(2.)).second; + CHECK_EQ(out.item(), 6.0f); + } + + // Input also in capture clause + { + auto x = array(1.0f); + auto fun = [x](array in) { return in + x; }; + auto out = vjp(fun, x, array(1.0f)).second; + CHECK_EQ(out.item(), 1.0f); + } + + // Throws on incorrectly shaped outputs + { + auto fun = [](array in) { return add(in, in); }; + CHECK_THROWS_AS(vjp(fun, zeros({1}), zeros({2})), std::invalid_argument); + } + + // Throws on wrong number of outputs + { + auto fun = [](std::vector inputs) { + return std::vector{inputs[0], inputs[0]}; + }; + CHECK_THROWS_AS( + vjp(fun, {zeros({1})}, {zeros({2})}), std::invalid_argument); + } + + // No dependence between input and output + { + auto fun = [](array in) { return array(1.); }; + auto out = vjp(fun, zeros({2}), array(1.)).second; + CHECK(array_equal(out, zeros({2})).item()); + } + + // Handles multiple outputs + { + auto x = array(1.); + auto y = array(2.); + auto z = array(3.); + auto fun = [](const std::vector& in) { + return std::vector{in[0] * in[1], in[1] * in[2]}; + }; + auto out = vjp(fun, {x, y, z}, {array(2.), array(3.)}).second; + CHECK_EQ(out.size(), 3); + CHECK_EQ(out[0].item(), 2.0f * 2.0f); + CHECK_EQ(out[1].item(), 1.0f * 2.0f + 3.0f * 3.0f); + CHECK_EQ(out[2].item(), 3.0f * 2.0f); + } +} + +TEST_CASE("test grad") { + { + auto x = array(1.0); + auto fun = [](array in) { return in + 1; }; + auto [y, dfdx] = value_and_grad(fun)(x); + CHECK_EQ(y.item(), 2.0f); + CHECK_EQ(dfdx.item(), 1.0f); + auto [z, d2fdx2] = value_and_grad(grad(fun))(x); + CHECK_EQ(z.item(), 1.0f); + CHECK_EQ(d2fdx2.item(), 0.0f); + } + + { + auto x = array(1.); + auto fun = [](array in) { return add(in, array(1.)); }; + auto dfdx = grad(fun); + CHECK(array_equal(dfdx(x), array(1.)).item()); + auto d2fdx2 = grad(grad(fun)); + CHECK(array_equal(d2fdx2(x), array(0.)).item()); + } + + { + auto x = array(1.); + auto expfn = [](array input) { return exp(input); }; + auto dfdx = grad(expfn); + CHECK_EQ(dfdx(x).item(), std::exp(1.0f)); + auto d2fdx2 = grad(grad(expfn)); + CHECK_EQ(d2fdx2(x).item(), std::exp(1.0f)); + auto d3fdx3 = grad(grad(grad(expfn))); + CHECK_EQ(d3fdx3(x).item(), std::exp(1.0f)); + } + + { + // Evaluating in the middle of the grad function throws + // as it breaks the graph + auto fn = [](array x) { + x = x + 2.0f; + eval(x); + return square(x); + }; + CHECK_THROWS(grad(fn)(array(1.0f))); + + // Ok since the output is independent of y + auto y = ones({3, 3}); + auto fn1 = [y](array x) { + x = x + 2.0f; + eval(y); + return square(x); + }; + 
auto dfdx = grad(fn1)(array(1.0f)); + CHECK_EQ(dfdx.item(), 6.0f); + + // Retain the graph to avoid breaking it + auto fn2 = [](array x) { + x = x + 2.0f; + eval({x}, true); + return square(x); + }; + dfdx = grad(fn2)(array(1.0f)); + CHECK_EQ(dfdx.item(), 6.0f); + } + + // Control flow in grad computation + { + auto fn = [](array x) { + if (x.item(true) > 1) { + return square(x); + } else { + return 4 * x; + } + }; + + auto dfdx = grad(fn)(array(0.5f)); + CHECK_EQ(dfdx.item(), 4.0f); + + dfdx = grad(fn)(array(1.5f)); + CHECK_EQ(dfdx.item(), 3.0f); + } + + // Grad with multiple inputs + { + auto fn = [](std::vector inputs) { return inputs[0] * inputs[1]; }; + auto x = array(2.0f); + auto y = array(3.0f); + + auto [value, grads] = value_and_grad(fn)({x, y}); + CHECK_EQ(value.item(), 6.0f); + CHECK_EQ(grads[0].item(), 3.0f); + + auto dfdx = grad(fn)({x, y})[0]; + CHECK_EQ(dfdx.item(), 3.0f); + + auto dfdy = grad(fn, 1)({x, y})[0]; + CHECK_EQ(dfdy.item(), 2.0f); + + // Negative indexing + dfdy = grad(fn, -1)({x, y})[0]; + CHECK_EQ(dfdy.item(), 2.0f); + + grads = grad(fn, {0, 1})({x, y}); + CHECK_EQ(grads[0].item(), 3.0f); + CHECK_EQ(grads[1].item(), 2.0f); + + CHECK_THROWS_AS( + grad(fn, std::vector{})({x, y}), std::invalid_argument); + CHECK_THROWS_AS(grad(fn, {0, 1, 2})({x, y}), std::invalid_argument); + CHECK_THROWS_AS(grad(fn, {0, 0})({x, y}), std::invalid_argument); + CHECK_THROWS_AS(grad(fn, -3)({x, y}), std::invalid_argument); + } +} + +TEST_CASE("test creation grads") { + // Test astype + { + auto fn = [](array a) { return astype(a, int32); }; + auto x = ones({4, 4}, float32); + auto out = vjp(fn, x, full({4, 4}, 2, int32)).second; + CHECK_EQ(out.dtype(), float32); + CHECK(array_equal(out, full({4, 4}, 2.0f)).item()); + + out = jvp(fn, x, full({4, 4}, 2, float32)).second; + CHECK_EQ(out.dtype(), int32); + CHECK(array_equal(out, full({4, 4}, 2, int32)).item()); + } + + // Test full + { + auto full_fn = [](array a) { return full({5, 5, 2}, a); }; + auto x = ones({2}, float32); + auto out = vjp(full_fn, x, full({5, 5, 2}, 2.0f)).second; + CHECK(array_equal(out, array({50.0f, 50.0f})).item()); + + out = jvp(full_fn, x, array({3.0f, 3.0f})).second; + CHECK(array_equal(out, full({5, 5, 2}, 3.0f)).item()); + } +} + +TEST_CASE("test op vjps") { + // Test abs + { + auto out = vjp([](array in) { return abs(in); }, array(-5.0f), array(1.0f)); + CHECK_EQ(out.second.item(), -1.0f); + } + + // Test sign + { + auto out = + vjp([](array in) { return sign(in); }, array(-5.0f), array(10.0f)); + CHECK_EQ(out.second.item(), 0.0f); + } + + // Test negate + { + auto out = vjp([](array in) { return -in; }, array(1.0), array(2.0)); + CHECK(array_equal(out.second, array(-2.)).item()); + } + + // Test square + { + auto out = + vjp([](array in) { return square(in); }, array(2.0f), array(3.0f)); + CHECK_EQ(out.second.item(), 12.0f); + } + + // Test sqrt + { + auto out = vjp( + [](array in) { return mlx::core::sqrt(in); }, array(4.0f), array(8.0f)); + CHECK_EQ(out.second.item(), 2.0f); + } + + // Test rsqrt + { + auto out = + vjp([](array in) { return rsqrt(in); }, array(4.0f), array(8.0f)); + CHECK_EQ(out.second.item(), -0.5f); + } + + // Test exp + { + auto out = vjp([](array in) { return exp(in); }, array(1.0f), array(2.0f)); + CHECK_EQ(out.second.item(), 2.0f * std::exp(1.0f)); + } + + // Test sin + { + auto out = + vjp([](array input) { return sin(input); }, array(1.0f), array(1.0f)); + CHECK(out.second.item() == doctest::Approx(std::cos(1.0f))); + } + + // Test cos + { + auto out = + vjp([](array input) { 
return cos(input); }, array(1.0f), array(1.0f)); + CHECK(out.second.item() == doctest::Approx(-std::sin(1.0f))); + } + + // Test log + { + auto out = vjp([](array in) { return log(in); }, array(2.0f), array(1.0f)); + CHECK_EQ(out.second.item(), 0.5f); + + out = vjp([](array in) { return log(in); }, array(2.0f), array(2.0f)); + CHECK_EQ(out.second.item(), 1.0f); + } + + // Test log1p + { + auto out = + vjp([](array in) { return log1p(in); }, array(1.0f), array(1.0f)); + CHECK_EQ(out.second.item(), 0.5f); + + out = vjp([](array in) { return log1p(in); }, array(1.0f), array(2.0f)); + CHECK_EQ(out.second.item(), 1.0f); + } + + constexpr auto inf = std::numeric_limits::infinity(); + + // Test erf + { + auto out = vjp([](array in) { return erf(in); }, array(inf), array(1.0f)); + CHECK_EQ(out.second.item(), 0.0f); + + out = vjp([](array in) { return erf(in); }, array(-inf), array(2.0f)); + CHECK_EQ(out.second.item(), 0.0f); + + out = vjp([](array in) { return erf(in); }, array(0.0f), array(1.0f)); + CHECK_EQ(out.second.item(), static_cast(M_2_SQRTPI)); + } + + // Test erfinv + { + auto out = + vjp([](array in) { return erfinv(in); }, array(1.0f), array(1.0f)); + CHECK_EQ(out.second.item(), inf); + + out = vjp([](array in) { return erfinv(in); }, array(-1.0f), array(2.0f)); + CHECK_EQ(out.second.item(), inf); + + out = vjp([](array in) { return erfinv(in); }, array(0.0f), array(1.0f)); + CHECK_EQ(out.second.item(), static_cast(1.0 / M_2_SQRTPI)); + } + + // Test sigmoid + { + auto out = + vjp([](array in) { return sigmoid(in); }, array(0.0f), array(1.0f)); + CHECK_EQ(out.second.item(), 0.25f); + + out = vjp([](array in) { return sigmoid(in); }, array(0.0f), array(2.0f)); + CHECK_EQ(out.second.item(), 0.5f); + } + + // Test add + { + auto fun = [](std::vector inputs) { + return std::vector{inputs[0] + inputs[1]}; + }; + auto out = vjp(fun, {array(1.0), array(2.0)}, {array(3.0)}).second; + CHECK_EQ(out[0].item(), 3.0); + CHECK_EQ(out[1].item(), 3.0); + + // Check with broadcasting + out = vjp(fun, {ones({3, 1}), ones({1, 2})}, {full({3, 2}, 2.0)}).second; + CHECK(array_equal(out[0], full({3, 1}, 4.0)).item()); + CHECK(array_equal(out[1], full({1, 2}, 6.0)).item()); + } + + // Test subtract + { + auto fun = [](std::vector inputs) { + return std::vector{inputs[0] - inputs[1]}; + }; + auto out = vjp(fun, {array(1.0), array(2.0)}, {array(3.0)}).second; + CHECK_EQ(out[0].item(), 3.0); + CHECK_EQ(out[1].item(), -3.0); + + // Check with broadcasting + out = vjp(fun, {ones({3, 1}), ones({1, 2})}, {ones({3, 2})}).second; + CHECK(array_equal(out[0], full({3, 1}, 2.0)).item()); + CHECK(array_equal(out[1], full({1, 2}, -3.0)).item()); + } + + // Test multiply + { + auto fun = [](std::vector inputs) { + return std::vector{inputs[0] * inputs[1]}; + }; + auto out = vjp(fun, {array(4.0f), array(2.0f)}, {array(3.0f)}).second; + CHECK_EQ(out[0].item(), 6.0f); + CHECK_EQ(out[1].item(), 12.0f); + + // Check with broadcasting + out = vjp(fun, {full({3, 1}, 2.0f), full({1, 2}, 4.0f)}, {ones({3, 2})}) + .second; + CHECK(array_equal(out[0], full({3, 1}, 8.0f)).item()); + CHECK(array_equal(out[1], full({1, 2}, 6.0)).item()); + } + + // Test divide + { + auto fun = [](std::vector inputs) { + return std::vector{inputs[0] / inputs[1]}; + }; + auto out = vjp(fun, {array(4.0f), array(2.0f)}, {array(1.0f)}).second; + CHECK_EQ(out[0].item(), 0.5f); + CHECK_EQ(out[1].item(), -1.0f); + + // Check with broadcasting + out = vjp(fun, {full({3, 1}, 4.0f), full({1, 2}, 2.0f)}, {ones({3, 2})}) + .second; + CHECK(array_equal(out[0], 
full({3, 1}, 1.0f)).item()); + CHECK(array_equal(out[1], full({1, 2}, -3.0f)).item()); + } + + // Test maximum + { + auto fun = [](std::vector inputs) { + return std::vector{maximum(inputs[0], inputs[1])}; + }; + auto out = vjp(fun, {array(5.0f), array(2.0f)}, {array(2.0f)}).second; + CHECK_EQ(out[0].item(), 2.0f); + CHECK_EQ(out[1].item(), 0.0f); + + out = vjp(fun, {array(2.0f), array(2.0f)}, {array(1.0f)}).second; + auto out_a = out[0].item(); + auto out_b = out[1].item(); + // When inputs are equal at most one gradient is nonzero + CHECK( + ((out_a == 1.0f && out_b == 0.0f) || (out_a == 0.0f && out_b == 1.0f))); + } + + // Test minimum + { + auto fun = [](std::vector inputs) { + return std::vector{minimum(inputs[0], inputs[1])}; + }; + auto out = vjp(fun, {array(4.0f), array(2.0f)}, {array(2.0f)}).second; + CHECK_EQ(out[0].item(), 0.0f); + CHECK_EQ(out[1].item(), 2.0f); + + out = vjp(fun, {array(2.0f), array(2.0f)}, {array(1.0f)}).second; + auto out_a = out[0].item(); + auto out_b = out[1].item(); + CHECK( + ((out_a == 1.0f && out_b == 0.0f) || (out_a == 0.0f && out_b == 1.0f))); + } + + // Test logaddexp + { + auto fun = [](std::vector inputs) { + return std::vector{logaddexp(inputs[0], inputs[1])}; + }; + + constexpr auto inf = std::numeric_limits::infinity(); + + auto out = vjp(fun, {array(2.0), array(2.0f)}, {array(1.0f)}).second; + CHECK_EQ(out[0].item(), 0.5f); + CHECK_EQ(out[1].item(), 0.5f); + out = vjp(fun, {array(2.0), array(2.0f)}, {array(2.0f)}).second; + CHECK_EQ(out[0].item(), 1.0f); + CHECK_EQ(out[1].item(), 1.0f); + + out = vjp(fun, {array(inf), array(2.0f)}, {array(1.0f)}).second; + CHECK_EQ(out[0].item(), 1.0f); + CHECK_EQ(out[1].item(), 0.0f); + + out = vjp(fun, {array(-inf), array(2.0f)}, {array(1.0f)}).second; + CHECK_EQ(out[0].item(), 0.0f); + CHECK_EQ(out[1].item(), 1.0f); + + out = vjp(fun, {array(-10.0f), array(-inf)}, {array(1.0f)}).second; + CHECK_EQ(out[0].item(), 1.0f); + CHECK_EQ(out[1].item(), 0.0f); + + out = vjp(fun, {array(-inf), array(-inf)}, {array(1.0f)}).second; + CHECK(std::isnan(out[0].item())); + CHECK(std::isnan(out[1].item())); + } + + // Test power + { + auto fun = [](std::vector inputs) { + return std::vector{inputs[0] ^ inputs[1]}; + }; + auto out = vjp(fun, {array(4.0f), array(3.0f)}, {array(1.0f)}).second; + CHECK_EQ(out[0].item(), 48.0f); + CHECK_EQ(out[1].item(), std::log(4.0f) * 64.0f); + } + + // Test sum + { + std::vector axes; + auto fun = [&axes](array input) { return sum(input, axes); }; + axes = {}; + auto out = vjp(fun, array(2.0f), array(3.0f)).second; + CHECK_EQ(out.item(), 3.0f); + + axes = {0}; + out = vjp(fun, array({}), array(3.0f)).second; + CHECK_EQ(out.size(), 0); + CHECK_EQ(out.shape(), std::vector{0}); + + axes = {0}; + out = vjp(fun, ones({2, 2, 2}), array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2})) + .second; + auto expected = + array({1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}, {2, 2, 2}); + CHECK(array_equal(out, expected).item()); + + axes = {1}; + out = vjp(fun, ones({2, 2, 2}), array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2})) + .second; + expected = + array({1.0f, 2.0f, 1.0f, 2.0f, 3.0f, 4.0f, 3.0f, 4.0f}, {2, 2, 2}); + CHECK(array_equal(out, expected).item()); + + axes = {2}; + out = vjp(fun, ones({2, 2, 2}), array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2})) + .second; + expected = + array({1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f}, {2, 2, 2}); + CHECK(array_equal(out, expected).item()); + } + + // Test prod + { + std::vector axes; + auto fun = [&axes](array input) { return prod(input, axes); }; + axes = {}; + auto out = 
vjp(fun, array(2.0f), array(3.0f)).second; + CHECK_EQ(out.item(), 3.0f); + + axes = {0}; + out = vjp(fun, + array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {2, 3}), + array( + {1.0f, 2.0f, 3.0f}, + { + 3, + })) + .second; + auto expected = array({4.0f, 10.0f, 18.0f, 1.0f, 4.0f, 9.0f}, {2, 3}); + CHECK(array_equal(out, expected).item()); + + axes = {0, 1}; + out = vjp(fun, + array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {2, 3}), + array(1.0f)) + .second; + expected = array({720.0f, 360.0f, 240.0f, 180.0f, 144.0f, 120.0f}, {2, 3}); + CHECK(array_equal(out, expected).item()); + } +} + +TEST_CASE("test gather and take grads") { + // Check linear takes + auto linear_f = [](array indices) { + auto fun_linear = [&indices](array input) { return take(input, indices); }; + + return fun_linear; + }; + + auto src = ones({4, 4}); + auto ind = array({0, 1, 2, 3}, uint32); + auto out = vjp(linear_f(ind), src, ones({4})).second; + auto out_1 = take(out, array({0}, uint32), 0); + auto out_2 = take(out, array({1, 2, 3}, uint32), 0); + CHECK(array_equal(out_1, ones({1, 4})).item()); + CHECK(array_equal(out_2, zeros({3, 4})).item()); + auto tangent = reshape(arange(16), {4, 4}); + out = jvp(linear_f(ind), src, tangent).second; + CHECK(array_equal(out, array({0, 1, 2, 3})).item()); + + src = ones({4}); + ind = array({0, 0, 0, 0}, uint32); + out = vjp(linear_f(ind), src, ones({4})).second; + out_1 = take(out, array({0}, uint32)); + CHECK_EQ(out_1.item(), 4.0f); + + tangent = arange(4); + out = jvp(linear_f(ind), src, tangent).second; + CHECK(array_equal(out, array({0, 0, 0, 0})).item()); + + // Check axis takes + src = ones({4, 4}); + ind = array({0, 1, 2, 3}, uint32); + + auto fun = [&ind](array input) { return take(input, ind, 0); }; + + out = vjp(fun, src, ones({4, 4})).second; + CHECK(array_equal(out, src).item()); + + out = jvp(fun, src, ones({4, 4})).second; + CHECK(array_equal(out, src).item()); + + // Check index throw + auto fun_throw = [](std::vector inputs) { + return std::vector{take(inputs[0], inputs[1])}; + }; + + CHECK_THROWS_AS( + vjp(fun_throw, {src, ind}, {ones({4, 4})}), std::invalid_argument); + + CHECK_THROWS_AS( + jvp(fun_throw, {src, ind}, {ones({4, 4}), ind}), std::invalid_argument); +} + +TEST_CASE("test slice grads") { + std::vector start = {5, 0, 0}; + std::vector stop = {7, 2, 4}; + std::vector strides = {1, 1, 1}; + + auto fn = [&start, &stop, &strides](array input) { + return slice(input, start, stop, strides); + }; + + auto src = ones({8, 8, 8}); + auto out = vjp(fn, src, ones({2, 2, 4})).second; + CHECK_EQ(sum(out).item(), 16.); + + out = jvp(fn, src, full({8, 8, 8}, 2.0f)).second; + CHECK(array_equal(out, full({2, 2, 4}, 2.0f)).item()); + + src = ones({4, 4}); + start = {2, 0}; + stop = {4, 4}; + strides = {1, 1}; + out = vjp(fn, src, ones({2, 4})).second; + auto out_1 = take(out, array({0, 1}, uint32), 0); + auto out_2 = take(out, array({2, 3}, uint32), 0); + + CHECK(array_equal(out_1, zeros({2, 4})).item()); + CHECK(array_equal(out_2, ones({2, 4})).item()); + + start = {0, 0}; + stop = {4, 4}; + strides = {2, 2}; + auto cotangent = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}); + out = vjp(fn, src, cotangent).second; + auto expected = astype( + array({1, 0, 2, 0, 0, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0}, {4, 4}), float32); + CHECK(array_equal(out, expected).item()); + + out = jvp(fn, src, ones({4, 4})).second; + CHECK(array_equal(out, ones({2, 2})).item()); + + // Empty slices. 
+ start = {0, 0}; + stop = {0, 4}; + cotangent = reshape(array({}), {0, 2}); + out = vjp(fn, src, cotangent).second; + CHECK(array_equal(out, zeros({4, 4})).item()); + + out = jvp(fn, src, ones({4, 4})).second; + CHECK_EQ(out.size(), 0); +} + +TEST_CASE("test min and max vjp") { + // Test min + { + std::vector axes; + array in({}); + array v({}); + array expected({}); + array out({}); + auto fun = [&axes](array input) { return min(input, axes); }; + + axes = {}; + in = array({2.0f}); + out = vjp(fun, array(2.0f), array(3.0f)).second; + CHECK_EQ(out.item(), 3.0f); + + axes = {0}; + in = reshape(array({1.0f, 2.0f, 2.0f, -1.0f}), {2, 2}); + v = array({3.0f, 7.0f}); + out = vjp(fun, in, v).second; + expected = array({3.0f, 0.0f, 0.0f, 7.0f}); + expected = reshape(expected, {2, 2}); + CHECK(array_equal(out, expected).item()); + + axes = {0, 2}; + in = reshape( + array({1.0f, 0.0f, 0.0f, 1.0f, -1.0f, -1.0f, 1.0f, 0.0f}), {2, 2, 2}); + v = array({3.0f, 7.0f}); + out = vjp(fun, in, v).second; + expected = array({0.0f, 0.0f, 3.5f, 0.0f, 1.5f, 1.5f, 0.0f, 3.5f}); + expected = reshape(expected, {2, 2, 2}); + CHECK(array_equal(out, expected).item()); + } + + // Test max + { + std::vector axes; + array in({}); + array v({}); + array expected({}); + array out({}); + auto fun = [&axes](array input) { return max(input, axes); }; + + axes = {}; + in = array({2.0f}); + out = vjp(fun, array(2.0f), array(3.0f)).second; + CHECK_EQ(out.item(), 3.0f); + + axes = {0}; + in = reshape(array({1.0f, 2.0f, 2.0f, -1.0f}), {2, 2}); + v = array({3.0f, 7.0f}); + out = vjp(fun, in, v).second; + expected = array({0.0f, 7.0f, 3.0f, 0.0f}); + expected = reshape(expected, {2, 2}); + CHECK(array_equal(out, expected).item()); + + axes = {0, 2}; + in = reshape( + array({1.0f, 0.0f, 0.0f, 1.0f, -1.0f, -1.0f, 1.0f, 0.0f}), {2, 2, 2}); + v = array({3.0f, 7.0f}); + out = vjp(fun, in, v).second; + expected = array({3.0f, 0.0f, 0.0f, 3.5f, 0.0f, 0.0f, 3.5f, 0.0f}); + expected = reshape(expected, {2, 2, 2}); + CHECK(array_equal(out, expected).item()); + } +} + +TEST_CASE("test reshape and transpose grads") { + { + auto fn = [](array a) { return reshape(a, {3, 4}); }; + + auto out = vjp(fn, ones({12}), full({3, 4}, 2.0f)).second; + CHECK(array_equal(out, full({12}, 2.0f)).item()); + + out = jvp(fn, ones({12}), full({12}, 2.0f)).second; + CHECK(array_equal(out, full({3, 4}, 2.0f)).item()); + } + + { + auto fn = [](array a) { return transpose(a, {1, 2, 0}); }; + + auto cotan = reshape(arange(24), {3, 4, 2}); + auto out = vjp(fn, ones({2, 3, 4}), cotan).second; + CHECK(array_equal(out, transpose(cotan, {2, 0, 1})).item()); + + auto tangent = reshape(arange(24), {2, 3, 4}); + out = jvp(fn, ones({2, 3, 4}), tangent).second; + CHECK(array_equal(out, transpose(tangent, {1, 2, 0})).item()); + } +} + +TEST_CASE("test copy grads") { + auto fn = [](array a) { return copy(a); }; + + auto cotan = arange(4, float32); + auto out = vjp(fn, ones({4}), cotan).second; + CHECK(array_equal(out, arange(4, float32)).item()); + + auto tangent = arange(4, float32); + out = jvp(fn, ones({4}), tangent).second; + CHECK(array_equal(out, tangent).item()); +} + +TEST_CASE("test matmul vjp") { + auto fun = [](std::vector inputs) { + return std::vector{matmul(inputs[0], inputs[1])}; + }; + + auto a = array({1.0f, 2.0f}, {1, 2}); + auto b = array({3.0f, 4.0f}, {2, 1}); + auto out = vjp(fun, {a, b}, {array({2.0f}, {1, 1})}).second; + + CHECK(array_equal(out[0], array({6.0f, 8.0f}, {1, 2})).item()); + CHECK(array_equal(out[1], array({2.0f, 4.0f}, {2, 1})).item()); + + a = 
array({1.0f, 2.0f}, {2, 1}); + b = array({3.0f, 4.0f}, {1, 2}); + out = vjp(fun, {a, b}, {array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2})}).second; + CHECK(array_equal(out[0], array({11.0f, 25.0f}, {2, 1})).item()); + CHECK(array_equal(out[1], array({7.0f, 10.0f}, {1, 2})).item()); + + a = array({1.0f, 2.0f, 1.0f, 2.0f}, {2, 2, 1}); + b = array({1.0f, 1.0f, 2.0f, 2.0f}, {2, 1, 2}); + auto vjps = vjp(fun, {a, b}, {ones({2, 2, 2})}).second; + auto vjpx = array({2.0f, 2.0f, 4.0f, 4.0f}, {2, 2, 1}); + auto vjpy = array({3.0f, 3.0f, 3.0f, 3.0f}, {2, 1, 2}); + CHECK(array_equal(vjps[0], vjpx).item()); + CHECK(array_equal(vjps[1], vjpy).item()); +} + +TEST_CASE("test concatenate grads") { + auto arrs = split(arange(5, float32), 5); + eval(arrs); + + auto fn = [&arrs](const std::vector& inputs) { + arrs[2] = inputs[0]; + arrs[4] = inputs[1]; + return std::vector{concatenate(arrs, 0)}; + }; + auto out = vjp(fn, {arrs[2], arrs[4]}, {arange(5, float32)}).second; + + CHECK_EQ(out.size(), 2); + CHECK_EQ(out[0].item(), 2.0f); + CHECK_EQ(out[1].item(), 4.0f); + + out = jvp(fn, {arrs[2], arrs[4]}, {array({2.0f}, {1}), array({3.0f}, {1})}) + .second; + CHECK_EQ(out.size(), 1); + CHECK( + array_equal(out[0], array({0.0f, 0.0f, 2.0f, 0.0f, 3.0f})).item()); +} + +TEST_CASE("test comparison grads") { + auto x = ones({3, 1}); + auto y = zeros({1, 3}); + + auto check_vjp_jvp = [&x, &y](auto fn) { + auto fn_wrap = [&fn](std::vector inputs) { + return std::vector{fn(inputs[0], inputs[1], default_device())}; + }; + auto out_shape = broadcast_shapes(x.shape(), y.shape()); + std::vector vjps = vjp(fn_wrap, {x, y}, {ones(out_shape)}).second; + bool correct = array_equal(vjps[0], zeros(x.shape())).item(); + correct &= array_equal(vjps[1], zeros(y.shape())).item(); + + std::vector jvps = + jvp(fn_wrap, {x, y}, {ones(x.shape()), ones(y.shape())}).second; + correct &= array_equal(jvps[0], zeros(out_shape)).item(); + return correct; + }; + + CHECK(check_vjp_jvp(equal)); + CHECK(check_vjp_jvp(greater)); + CHECK(check_vjp_jvp(less)); + CHECK(check_vjp_jvp(greater_equal)); + CHECK(check_vjp_jvp(less_equal)); +} + +TEST_CASE("test as_strided grads") { + auto x = ones({11}); + std::vector shape = {5, 5}; + std::vector strides = {1, 1}; + size_t offset = 0; + + auto fun = [&shape, &strides, &offset](array x) { + return as_strided(x, shape, strides, offset); + }; + + auto out = vjp(fun, x, ones(shape)).second; + auto expected = array({1, 2, 3, 4, 5, 4, 3, 2, 1, 0, 0}); + CHECK(array_equal(out, expected).item()); + + offset = 1; + out = vjp(fun, x, ones(shape)).second; + expected = array({0, 1, 2, 3, 4, 5, 4, 3, 2, 1, 0}); + CHECK(array_equal(out, expected).item()); + + offset = 3; + shape = {3, 3}; + strides = {0, 1}; + out = vjp(fun, x, ones(shape)).second; + expected = array({0, 0, 0, 3, 3, 3, 0, 0, 0, 0, 0}); + CHECK(array_equal(out, expected).item()); + + offset = 3; + shape = {3, 3}; + strides = {0, 1}; + out = vjp(fun, x, reshape(astype(arange(9), x.dtype()), {3, 3})).second; + expected = array({0, 0, 0, 9, 12, 15, 0, 0, 0, 0, 0}); + CHECK(array_equal(out, expected).item()); +} + +TEST_CASE("test jvp from vjp") { + // Unary elementwise ops + { + auto x = random::uniform({5, 10}); + eval(x); + + auto compute_derivs = [&x](auto fn) { + auto fn_wrap = [&fn](array input) { return fn(input, default_device()); }; + + // Compute vjp + array vjp_out = vjp(fn_wrap, x, ones(x.shape())).second; + + // Compute jvp + array jvp_out = jvp(fn_wrap, x, ones(x.shape())).second; + + return array_equal(vjp_out, jvp_out).item(); + }; + + 
CHECK(compute_derivs(mlx::core::abs)); + CHECK(compute_derivs(mlx::core::cos)); + CHECK(compute_derivs(mlx::core::erf)); + CHECK(compute_derivs(mlx::core::erfinv)); + CHECK(compute_derivs(mlx::core::exp)); + CHECK(compute_derivs(mlx::core::log)); + CHECK(compute_derivs(mlx::core::log1p)); + CHECK(compute_derivs(mlx::core::negative)); + CHECK(compute_derivs(mlx::core::sigmoid)); + CHECK(compute_derivs(mlx::core::sign)); + CHECK(compute_derivs(mlx::core::sin)); + CHECK(compute_derivs(mlx::core::square)); + CHECK(compute_derivs(mlx::core::sqrt)); + CHECK(compute_derivs(mlx::core::rsqrt)); + } + + // Binary elementwise ops + { + auto x = random::uniform({5, 10}); + auto y = random::uniform({5, 10}); + eval(x, y); + + auto compute_derivs = [&x, &y](auto fn) { + auto fn_wrap = [&fn](std::vector inputs) { + return std::vector{fn(inputs[0], inputs[1], default_device())}; + }; + + // Compute vjp and add results + auto vjps = vjp(fn_wrap, {x, y}, {ones(x.shape())}).second; + array vjp_out = add(vjps[0], vjps[1]); + + // Compute jvp + array jvp_out = + jvp(fn_wrap, {x, y}, {ones(x.shape()), ones(y.shape())}).second[0]; + return array_equal(vjp_out, jvp_out).item(); + }; + + CHECK(compute_derivs(add)); + CHECK(compute_derivs(divide)); + CHECK(compute_derivs(logaddexp)); + CHECK(compute_derivs(maximum)); + CHECK(compute_derivs(minimum)); + CHECK(compute_derivs(multiply)); + CHECK(compute_derivs(subtract)); + CHECK(compute_derivs(power)); + } +} + +TEST_CASE("test complex gradients") { + { + auto add_fn = [](std::vector inputs) { + return std::vector{add(inputs[0], inputs[1], default_device())}; + }; + + // Compute jvp + auto x = array(complex64_t{1.0, 1.0}); + auto y = array(complex64_t{1.0, 1.0}); + auto x_tan = array(complex64_t{1.0, 2.0}); + auto y_tan = array(complex64_t{2.0, 1.0}); + auto jvp_out = jvp(add_fn, {x, y}, {x_tan, y_tan}).second; + CHECK_EQ(jvp_out[0].item(), complex64_t{3.0, 3.0}); + + // Compute vjp + auto cotan = array(complex64_t{3.0, 3.0}); + auto vjp_out = vjp(add_fn, {x, y}, {cotan}).second; + CHECK_EQ(vjp_out[0].item(), complex64_t{3.0, 3.0}); + CHECK_EQ(vjp_out[1].item(), complex64_t{3.0, 3.0}); + } + + { + // Compute jvp + auto x = array(complex64_t{2.0, 4.0}); + auto y = array(3.0f); + + auto x_tan = array(complex64_t{1.0, 2.0}); + auto y_tan = array(2.0f); + + auto out = jvp([x](array a) { return multiply(a, x); }, y, y_tan).second; + CHECK_EQ(out.item(), complex64_t{4.0, 8.0}); + + out = jvp([y](array a) { return multiply(a, y); }, x, x_tan).second; + CHECK_EQ(out.item(), complex64_t{3.0, 6.0}); + + auto cotan = array(complex64_t{2.0, 3.0}); + out = vjp([x](array a) { return multiply(a, x); }, y, cotan).second; + CHECK_EQ(out.dtype(), float32); + CHECK_EQ(out.item(), -8.0); + + out = vjp([y](array a) { return multiply(a, y); }, x, cotan).second; + CHECK_EQ(out.item(), complex64_t{6.0, 9.0}); + } +} + +TEST_CASE("test scan grads") { + // Test cumsum + { + int axis = 0; + int reverse = false; + int inclusive = true; + auto fun = [&axis, &reverse, &inclusive](array x) { + return cumsum(x, axis, reverse, inclusive); + }; + + auto out = vjp(fun, ones({4}), ones({4})).second; + auto expected = array({4.0f, 3.0f, 2.0f, 1.0f}, {4}); + CHECK(array_equal(out, expected).item()); + + reverse = true; + out = vjp(fun, ones({4}), ones({4})).second; + expected = array({1.0f, 2.0f, 3.0f, 4.0f}, {4}); + CHECK(array_equal(out, expected).item()); + + reverse = true; + inclusive = false; + out = vjp(fun, ones({4}), ones({4})).second; + expected = array({0.0f, 1.0f, 2.0f, 3.0f}, {4}); + 
CHECK(array_equal(out, expected).item()); + + reverse = false; + inclusive = false; + out = vjp(fun, ones({4}), ones({4})).second; + expected = array({3.0f, 2.0f, 1.0f, 0.0f}, {4}); + CHECK(array_equal(out, expected).item()); + } + + // Test cumprod + { + int axis = 0; + int reverse = false; + int inclusive = true; + auto fun = [&axis, &reverse, &inclusive](array x) { + return cumprod(x, axis, reverse, inclusive); + }; + + auto x = array({1.0f, 2.0f, 3.0f, 4.0f}, {4}); + auto g = array({1.0f, 2.0f, 3.0f, 4.0f}, {4}); + auto out = vjp(fun, x, g).second; + auto expected = array({119.0f, 59.0f, 38.0f, 24.0f}, {4}); + CHECK(allclose(out, expected).item()); + + reverse = true; + out = vjp(fun, x, g).second; + expected = array({24.0f, 36.0f, 36.0f, 31.0f}, {4}); + CHECK(array_equal(out, expected).item()); + + inclusive = false; + out = vjp(fun, x, g).second; + expected = array({0.0f, 12.0f, 16.0f, 15.0f}, {4}); + CHECK(array_equal(out, expected).item()); + + reverse = false; + out = vjp(fun, x, g).second; + expected = array({32.0f, 15.0f, 8.0f, 0.0f}, {4}); + CHECK(array_equal(out, expected).item()); + } + + // Test cumsum jvp + { + int axis = 0; + int reverse = false; + int inclusive = true; + auto fun = [&axis, &reverse, &inclusive](array x) { + return cumsum(x, axis, reverse, inclusive); + }; + + auto x = array({1.0f, 2.0f, 3.0f, 4.0f}, {4}); + auto out = jvp(fun, x, ones({4})).second; + auto expected = array({1.0f, 2.0f, 3.0f, 4.0f}, {4}); + CHECK(array_equal(out, expected).item()); + + reverse = true; + out = jvp(fun, x, ones({4})).second; + expected = array({4.0f, 3.0f, 2.0f, 1.0f}, {4}); + CHECK(array_equal(out, expected).item()); + + inclusive = false; + out = jvp(fun, x, ones({4})).second; + expected = array({3.0f, 2.0f, 1.0f, 0.0f}, {4}); + CHECK(array_equal(out, expected).item()); + + reverse = false; + out = jvp(fun, x, ones({4})).second; + expected = array({0.0f, 1.0f, 2.0f, 3.0f}, {4}); + CHECK(array_equal(out, expected).item()); + } +} diff --git a/tests/device_tests.cpp b/tests/device_tests.cpp new file mode 100644 index 000000000..dedf64571 --- /dev/null +++ b/tests/device_tests.cpp @@ -0,0 +1,33 @@ +#include "doctest/doctest.h" + +#include + +#include "mlx/mlx.h" + +using namespace mlx::core; + +TEST_CASE("test device placement") { + auto device = default_device(); + Device d = metal::is_available() ? 
Device::gpu : Device::cpu; + if (std::getenv("DEVICE") == nullptr) { + CHECK_EQ(device, d); + } + + array x(1.0f); + array y(1.0f); + auto z = add(x, y, default_device()); + if (metal::is_available()) { + z = add(x, y, Device::gpu); + z = add(x, y, Device(Device::gpu, 0)); + } else { + CHECK_THROWS_AS(set_default_device(Device::gpu), std::invalid_argument); + CHECK_THROWS_AS(add(x, y, Device::gpu), std::invalid_argument); + } + + // Set the default device to the CPU + set_default_device(Device::cpu); + CHECK_EQ(default_device(), Device::cpu); + + // Revert + set_default_device(device); +} diff --git a/tests/eval_tests.cpp b/tests/eval_tests.cpp new file mode 100644 index 000000000..88af76d58 --- /dev/null +++ b/tests/eval_tests.cpp @@ -0,0 +1,97 @@ +#include "doctest/doctest.h" + +#include "mlx/mlx.h" + +using namespace mlx::core; + +TEST_CASE("test eval") { + { + array x(1.0); + array y(1); + array z(true); + eval({x, y, z}); + CHECK_EQ(x.item(), 1.0); + } + + { + array x(1.0); + array y = ones({2, 2}); + array z(true); + eval({x, y, z}); + CHECK(array_equal(y, array({1.0, 1.0, 1.0, 1.0}, {2, 2})).item()); + } +} + +TEST_CASE("test eval multiple") { + auto x = ones({10, 10}); + auto y = ones({10, 10}); + eval({x, y}); + CHECK(array_equal(x, y).item()); + + auto a = x + y; + auto b = x - y; + eval({a, b}); + CHECK(array_equal(a, full({10, 10}, 2.0f)).item()); + CHECK(array_equal(b, full({10, 10}, 0.0f)).item()); + + x = ones({10, 10}); + y = ones({10, 10}); + eval(x, y); + CHECK(array_equal(x, y).item()); + + a = x + y; + b = x - y; + eval(a, b); + CHECK(array_equal(a, full({10, 10}, 2.0f)).item()); + CHECK(array_equal(b, full({10, 10}, 0.0f)).item()); +} + +TEST_CASE("test eval with tracer") { + auto x = array(1); + x.set_tracer(true); + + // Ok, x is not a node + eval(x); + + x = ones({2, 3}); + x.set_tracer(true); + CHECK_THROWS(eval(x)); + + // Ok retain_graph=true + eval({x}, true); + + // Make sure all arguments are checked + auto y = ones({2, 3}); + CHECK_THROWS(eval(x, y)); +} + +TEST_CASE("test eval graph retention") { + auto x = array(1); + auto y = array(2); + auto z = x + y; + eval({z}, true); + CHECK(z.has_primitive()); + CHECK(z.is_evaled()); + CHECK_EQ(z.item(true), 3); + CHECK(z.has_primitive()); + CHECK(z.is_evaled()); + + CHECK_EQ(z.item(), 3); + CHECK(!z.has_primitive()); + CHECK(z.is_evaled()); + + z = x + y; + auto a = z + x; + auto b = a + y; + eval({b}, true); + CHECK(z.has_primitive()); + CHECK(z.is_evaled()); + CHECK(a.has_primitive()); + CHECK(a.is_evaled()); + + eval({b}, false); + CHECK(!z.has_primitive()); + CHECK(z.is_evaled()); + CHECK(!a.has_primitive()); + CHECK(a.is_evaled()); +} diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp new file mode 100644 index 000000000..45f36e473 --- /dev/null +++ b/tests/load_tests.cpp @@ -0,0 +1,81 @@ +#include +#include +#include + +#include "doctest/doctest.h" + +#include "mlx/mlx.h" + +using namespace mlx::core; + +std::string get_temp_file(const std::string& name) { + return std::filesystem::temp_directory_path().append(name); +} + +TEST_CASE("test single array serialization") { + // Basic test + { + auto a = random::uniform(-5.f, 5.f, {2, 5, 12}, float32); + + std::string file_path = get_temp_file("test_arr.npy"); + + save(file_path, a); + auto b = load(file_path); + + CHECK_EQ(a.dtype(), b.dtype()); + CHECK_EQ(a.shape(), b.shape()); + CHECK(array_equal(a, b).item()); + } + + // Other shapes + { + auto a = random::uniform( + -5.f, + 5.f, + { + 1, + }, + float32); + + std::string file_path = 
get_temp_file("test_arr_0.npy"); + + save(file_path, a); + auto b = load(file_path); + + CHECK_EQ(a.dtype(), b.dtype()); + CHECK_EQ(a.shape(), b.shape()); + CHECK(array_equal(a, b).item()); + } + + { + auto a = random::uniform( + -5.f, + 5.f, + { + 46, + }, + float32); + + std::string file_path = get_temp_file("test_arr_1.npy"); + + save(file_path, a); + auto b = load(file_path); + + CHECK_EQ(a.dtype(), b.dtype()); + CHECK_EQ(a.shape(), b.shape()); + CHECK(array_equal(a, b).item()); + } + + { + auto a = random::uniform(-5.f, 5.f, {5, 2, 1, 3, 4}, float32); + + std::string file_path = get_temp_file("test_arr_2.npy"); + + save(file_path, a); + auto b = load(file_path); + + CHECK_EQ(a.dtype(), b.dtype()); + CHECK_EQ(a.shape(), b.shape()); + CHECK(array_equal(a, b).item()); + } +} diff --git a/tests/scheduler_tests.cpp b/tests/scheduler_tests.cpp new file mode 100644 index 000000000..781b2aa70 --- /dev/null +++ b/tests/scheduler_tests.cpp @@ -0,0 +1,119 @@ +#include "doctest/doctest.h" + +#include "mlx/mlx.h" +#include "mlx/scheduler.h" + +using namespace mlx::core; + +TEST_CASE("test stream management") { + auto s1 = default_stream(default_device()); + CHECK_EQ(s1.device, default_device()); + + auto s2 = new_stream(default_device()); + CHECK_EQ(s2.device, default_device()); + CHECK_NE(s1, s2); + + // Check that default streams have the correct devices + if (metal::is_available()) { + auto s_gpu = default_stream(Device::gpu); + CHECK_EQ(s_gpu.device, Device::gpu); + } else { + CHECK_THROWS_AS(default_stream(Device::gpu), std::invalid_argument); + } + auto s_cpu = default_stream(Device::cpu); + CHECK_EQ(s_cpu.device, Device::cpu); + + s_cpu = new_stream(Device::cpu); + CHECK_EQ(s_cpu.device, Device::cpu); + + if (metal::is_available()) { + auto s_gpu = new_stream(Device::gpu); + CHECK_EQ(s_gpu.device, Device::gpu); + } else { + CHECK_THROWS_AS(new_stream(Device::gpu), std::invalid_argument); + } +} + +TEST_CASE("test asynchronous launch") { + auto s1 = default_stream(default_device()); + auto s2 = new_stream(default_device()); + + // Make sure streams execute asynchronously + int x = 1; + auto p1 = std::make_shared>(); + auto p2 = std::make_shared>(); + auto f1 = p1->get_future().share(); + auto f2 = p2->get_future().share(); + auto fn1 = [&x, p = std::move(p1)]() { + x++; + p->set_value(); + }; + auto fn2 = [&x, p = std::move(p2), f = std::move(f1)]() { + f.wait(); + x *= 5; + p->set_value(); + }; + + // fn2 is launched first and is waiting on fn1 but since + // they are on different streams there is no deadlock. + scheduler::enqueue(s2, std::move(fn2)); + scheduler::enqueue(s1, std::move(fn1)); + + f2.wait(); + + CHECK_EQ(x, 10); +} + +TEST_CASE("test stream placement") { + auto s1 = default_stream(default_device()); + auto s2 = new_stream(default_device()); + + { + // Wait on stream 1 + auto p = std::make_shared>(); + auto f = p->get_future().share(); + scheduler::enqueue(s1, [f = std::move(f)]() { f.wait(); }); + + // Do some work on stream 2 + auto x = zeros({100}, float32, s2); + auto y = ones({100}, float32, s2); + auto z = add(x, y, s2); + eval(z); + p->set_value(); + } + + { + // Wait on stream 1 + auto p = std::make_shared>(); + auto f = p->get_future().share(); + scheduler::enqueue(s1, [f = std::move(f)]() { f.wait(); }); + + // Do some work on stream 2 + auto fn = [&s2](array a) { return add(a, add(a, a, s2), s2); }; + auto x = zeros({100}, s2); + + // The whole vjp computation should happen + // on the second stream otherwise this will hang. 
+ auto [out, dout] = vjp(fn, x, ones({100}, s2)); + + // The whole jvp computation should happen on the + // second stream. + std::tie(out, dout) = jvp(fn, x, ones({100}, s2)); + eval(out, dout); + + p->set_value(); + } +} + +TEST_CASE("test scheduler races") { + auto x = zeros({1}); + auto y = zeros({100}); + eval(x, y); + auto a = exp(x); + eval(a); + a = exp(x); + for (int i = 0; i < 10000; ++i) { + y = exp(y); + } + eval(a, y); +} diff --git a/tests/utils_tests.cpp b/tests/utils_tests.cpp new file mode 100644 index 000000000..867a2cf58 --- /dev/null +++ b/tests/utils_tests.cpp @@ -0,0 +1,26 @@ + +#include "doctest/doctest.h" + +#include "mlx/mlx.h" + +using namespace mlx::core; + +TEST_CASE("test type promotion") { + for (auto t : {bool_, uint32, int32, int64, float32}) { + auto a = array(0, t); + CHECK_EQ(result_type({a}), t); + + std::vector arrs = {array(0, t), array(0, t)}; + CHECK_EQ(result_type(arrs), t); + } + + { + std::vector arrs = {array(false), array(0, int32)}; + CHECK_EQ(result_type(arrs), int32); + } + + { + std::vector arrs = {array(0, int32), array(false), array(0.0f)}; + CHECK_EQ(result_type(arrs), float32); + } +}
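The tree utilities exercised in `python/tests/test_tree.py` above (`tree_map`, `tree_flatten`, `tree_unflatten`) operate on arbitrary nestings of dicts, lists, and tuples whose leaves are arrays or scalars. The following is a minimal sketch of the round-trip those tests check, using only the `mlx.utils` functions that appear in the tests; the `params` container and its key names are illustrative, not part of the library.

```python
import mlx.core as mx
import mlx.utils

# A "tree" is any nesting of dicts/lists/tuples whose leaves are arrays or scalars.
params = {"w": mx.ones([4, 4]), "b": [mx.zeros([4]), 0.5]}

# tree_map applies a function to every leaf and preserves the container structure.
scaled = mlx.utils.tree_map(lambda x: x * 2, params)

# tree_flatten yields (key, leaf) pairs; tree_unflatten inverts it.
flat = mlx.utils.tree_flatten(scaled)
print([key for key, _ in flat])  # dotted keys, e.g. 'w', 'b.0', 'b.1' (exact keys/order may differ)
restored = mlx.utils.tree_unflatten(flat)
```

The `test_tree_flatten` case above relies on exactly this property: flattening followed by unflattening reproduces the original container.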
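`python/tests/test_vmap.py` above also pins down how `mx.vmap` maps `in_axes`/`out_axes` onto nested inputs: `in_axes` has one entry per positional argument, and within an argument it may be a container matching that argument's structure, giving one axis per leaf. The sketch below restates the tree case from `test_tree` outside the test harness; the helper `my_fun` and the shapes are taken directly from that test.

```python
import mlx.core as mx

def my_fun(tree):
    # Combine the leaves of a small dict/tuple "tree" of arrays.
    return (tree["a"] + tree["b"][0]) * tree["b"][1]

tree = {
    "a": mx.random.uniform(shape=(2, 4)),  # batched along axis 0
    "b": (
        mx.random.uniform(shape=(4, 2)),   # batched along axis 1
        mx.random.uniform(shape=(4, 2)),   # batched along axis 1
    ),
}

# One in_axes entry for the single positional argument; inside it, one axis per leaf.
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (1, 1)},), out_axes=0)(tree)

# The unbatched equivalent transposes the leaves that were batched along axis 1.
expected = (tree["a"] + tree["b"][0].T) * tree["b"][1].T
print(mx.array_equal(out, expected))  # True
```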
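The C++ autograd tests above exercise `grad`, `value_and_grad`, `vjp`, and `jvp` against hand-computed derivatives. The same transformations are exposed through the Python bindings; the snippet below is a rough sketch only, assuming `mx.grad` and `mx.value_and_grad` behave like the `value_and_grad(fun)(x)` pattern in the `test grad` case above, and the `loss` function is made up for illustration.

```python
import mlx.core as mx

def loss(x):
    # A tiny scalar-valued function: sum of squares.
    return mx.sum(mx.square(x))

x = mx.ones([3])

# Separate transformations: one call for the value, another for the gradient.
value = loss(x)
dfdx = mx.grad(loss)(x)

# Fused transformation: a single call returns both, avoiding a second forward pass.
value2, dfdx2 = mx.value_and_grad(loss)(x)

print(value, value2)  # both 3.0
print(dfdx, dfdx2)    # both [2, 2, 2]
```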