awni's commit files

This commit is contained in:
Awni Hannun
2023-11-29 10:30:41 -08:00
parent e411fcae68
commit 8ca7f9e8e9
130 changed files with 30159 additions and 0 deletions

View File

@@ -0,0 +1,198 @@
#include <iostream>
#include <sstream>
#include "mlx/mlx.h"
#include "time_utils.h"
using namespace mlx::core;
void time_irregular_binary_ops_1D() {
auto device = default_device();
int size = 1000000;
int step = 2;
auto a = random::uniform({size});
auto b = random::uniform({size});
eval(a, b);
a = slice(a, {0}, {size}, {step});
b = slice(b, {0}, {size}, {step});
TIMEM("1D strided", add, a, b, device);
}
void time_irregular_binary_ops_2D() {
auto device = default_device();
int size = 2048;
auto a = random::uniform({size, size});
auto b = random::uniform({size, size});
eval(a, b);
TIMEM("2D regular", add, a, b, device);
b = transpose(b);
eval(b);
TIMEM("2D transpose", add, a, b, device);
b = random::uniform({size});
eval(b);
TIMEM("2D broadcast dim 0", add, a, b, device);
b = reshape(b, {size, 1});
eval(b);
TIMEM("2D broadcast dim 1", add, a, b, device);
}
void time_irregular_binary_ops_3D() {
auto device = default_device();
int d0 = 32;
int d1 = 512;
int d2 = 512;
auto a = random::uniform({d0, d1, d2});
auto b = random::uniform({d0, d1, d2});
TIMEM("3D regular", add, a, b, device);
b = transpose(b, {0, 2, 1});
TIMEM("3D transpose", add, a, b, device);
b = random::uniform({d1, d2});
TIMEM("3D broadcast dim 0", add, a, b, device);
b = random::uniform({d0, 1, d2});
TIMEM("3D broadcast dim 1", add, a, b, device);
b = random::uniform({d0, d1, 1});
TIMEM("3D broadcast dim 2", add, a, b, device);
b = random::uniform({d2});
TIMEM("3D broadcast dims 0, 1", add, a, b, device);
b = random::uniform({d1, 1});
TIMEM("3D broadcast dims 0, 2", add, a, b, device);
b = random::uniform({d0, 1, 1});
TIMEM("3D broadcast dims 1, 2", add, a, b, device);
}
void time_irregular_binary_ops_4D() {
auto device = default_device();
std::vector<int> shape = {8, 8, 512, 512};
auto a = random::uniform(shape);
auto b = random::uniform(shape);
TIMEM("4D regular", add, a, b, device);
b = transpose(b, {0, 1, 3, 2});
TIMEM("4D transpose", add, a, b, device);
std::string om = "4D broadcast dims ";
for (int i = 0; i < shape.size(); ++i) {
shape[i] = 1;
b = random::uniform(shape);
std::ostringstream msg;
msg << om << i;
TIMEM(msg.str(), add, a, b, device);
for (int j = i + 1; j < shape.size(); ++j) {
shape[j] = 1;
std::ostringstream msg;
msg << om << i << ", " << j;
b = random::uniform(shape);
TIMEM(msg.str(), add, a, b, device);
shape[j] = a.shape(j);
for (int k = j + 1; k < shape.size(); ++k) {
shape[k] = 1;
std::ostringstream msg;
msg << om << i << ", " << j << ", " << k;
b = random::uniform(shape);
TIMEM(msg.str(), add, a, b, device);
shape[k] = a.shape(k);
}
}
shape[i] = a.shape(i);
}
}
void time_irregular_reshape() {
auto device = default_device();
std::vector<int> shape;
auto reshape_fn = [&shape, device](const array& a) {
return reshape(a, shape, device);
};
int size = 64;
int d = 2 * size;
auto a = random::uniform({d, d, d});
shape = {8 * size, size, size};
TIMEM("3D contiguous", reshape_fn, a);
a = transpose(a);
shape = {8 * size, size, size};
TIMEM("3D transpose", reshape_fn, a);
a = transpose(a, {1, 2, 0});
shape = {8 * size, size, size};
TIMEM("3D transpose dims 1 2", reshape_fn, a);
a = broadcast_to(random::uniform({d, d}), {d, d, d});
TIMEM("3D broadcast dim 0", reshape_fn, a);
a = broadcast_to(random::uniform({d, 1, d}), {d, d, d});
TIMEM("3D broadcast dim 1", reshape_fn, a);
a = broadcast_to(random::uniform({d, d, 1}), {d, d, d});
TIMEM("3D broadcast dim 2", reshape_fn, a);
a = broadcast_to(random::uniform({d}), {d, d, d});
TIMEM("3D broadcast dims 0, 1", reshape_fn, a);
a = broadcast_to(random::uniform({d, 1}), {d, d, d});
TIMEM("3D broadcast dims 0, 2", reshape_fn, a);
a = broadcast_to(random::uniform({d, 1, 1}), {d, d, d});
TIMEM("3D broadcast dims 1, 2", reshape_fn, a);
a = broadcast_to(random::uniform({1, 1, 1}), {d, d, d});
TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
}
void time_irregular_astype_1D() {
auto device = default_device();
int size = 1000000;
int step = 2;
auto a = random::uniform({size});
a = slice(a, {0}, {size}, {step});
TIMEM("1D strided", astype, a, int32, device);
}
void time_irregular_astype_2D() {
auto device = default_device();
int size = 2048;
std::vector<int> shape = {size, size};
auto a = random::uniform(shape);
TIMEM("2D regular", astype, a, int32, device);
a = transpose(a);
TIMEM("2D transpose", astype, a, int32, device);
a = broadcast_to(random::uniform({size}), shape);
TIMEM("2D broadcast dim 0", astype, a, int32, device);
a = broadcast_to(random::uniform({size, 1}), shape);
TIMEM("2D broadcast dim 1", astype, a, int32, device);
}
int main(int argc, char** argv) {
if (argc > 1) {
bool use_gpu = !strcmp(argv[1], "gpu");
set_default_device(use_gpu ? Device::gpu : Device::cpu);
}
std::cout << "Benchmarks for " << default_device() << std::endl;
time_irregular_binary_ops_1D();
time_irregular_binary_ops_2D();
time_irregular_binary_ops_3D();
time_irregular_binary_ops_4D();
time_irregular_reshape();
time_irregular_astype_1D();
time_irregular_astype_2D();
}

View File

@@ -0,0 +1,247 @@
#include "mlx/mlx.h"
#include "time_utils.h"
using namespace mlx::core;
void time_creation_ops() {
int M = 2000;
int N = 500;
auto shape = {M, N};
auto full_fp32 = [&]() { return full(shape, 3.3f); };
TIME(full_fp32);
auto zeros_fp32 = [&]() { return zeros(shape, float32); };
TIME(zeros_fp32);
auto ones_fp32 = [&]() { return ones(shape, float32); };
TIME(ones_fp32);
auto arange_fp32 = [&]() { return arange(0.0, 10.0, 1e-4); };
TIME(arange_fp32);
}
void time_type_conversions() {
int M = 2000;
int N = 500;
auto shape = {M, N};
auto device = default_device();
auto a = zeros(shape, float32);
eval(a);
TIMEM("float32 to int32", astype, a, int32, device);
TIMEM("float32 to uint32", astype, a, uint32, device);
a = zeros(shape, int32);
eval(a);
TIMEM("int32 to float32", astype, a, float32, device);
a = zeros(shape, bool_);
eval(a);
TIMEM("bool to float32", astype, a, float32, device);
TIMEM("bool to int32", astype, a, int32, device);
TIMEM("bool to uint32", astype, a, uint32, device);
}
void time_random_generation() {
int M = 2000;
int N = 500;
auto uniform = [&]() { return random::uniform({M, N}, float32); };
TIME(uniform);
auto normal = [&]() { return random::normal({M, N}, float32); };
TIME(normal);
}
void time_unary_ops() {
int M = 2000;
int N = 500;
auto device = default_device();
auto a = random::normal({M, N});
eval(a);
TIME(mlx::core::abs, a, device);
TIME(negative, a, device);
TIME(sign, a, device);
TIME(square, a, device);
TIME(mlx::core::sqrt, a, device);
TIME(rsqrt, a, device);
TIME(mlx::core::exp, a, device);
a = random::uniform({M, N});
TIME(mlx::core::log, a, device);
}
void time_binary_ops() {
int M = 1000, N = 100, K = 10;
auto a = random::uniform({M, N, K});
auto b = random::uniform({M, N, K});
auto device = default_device();
eval(a, b);
TIME(add, a, b, device);
TIME(subtract, a, b, device);
TIME(multiply, a, b, device);
TIME(divide, a, b, device);
TIME(maximum, a, b, device);
TIME(minimum, a, b, device);
b = random::uniform({1});
eval(b);
TIMEM("scalar", add, a, b, device);
TIMEM("vector-scalar", subtract, a, b, device);
TIMEM("scalar-vector", subtract, b, a, device);
TIMEM("scalar", multiply, a, b, device);
TIMEM("vector-scalar", divide, a, b, device);
TIMEM("scalar-vector", divide, b, a, device);
a = broadcast_to(random::uniform({1}), {1000, 100});
b = broadcast_to(random::uniform({1}), {1000, 100});
eval(a, b);
TIMEM("scalar-scalar broadcast", add, a, b, device);
TIMEM("scalar-scalar broadcast", subtract, a, b, device);
TIMEM("scalar-scalar broadcast", multiply, a, b, device);
TIMEM("scalar-scalar broadcast", divide, a, b, device);
}
void time_strided_ops() {
int M = 50, N = 50, O = 50, P = 50;
auto a = random::uniform({M, N, O, P});
auto b = random::uniform({M, N, O, P});
auto device = default_device();
eval(a, b);
TIMEM("non-strided", add, a, b, device);
a = transpose(a, {1, 0, 2, 3});
b = transpose(b, {3, 2, 0, 1});
eval(a, b);
TIMEM("strided", add, a, b, device);
}
void time_comparisons() {
int M = 1000, N = 100, K = 10;
auto a = random::uniform({M, N, K});
auto b = random::uniform({M, N, K});
auto device = default_device();
eval(a, b);
TIME(equal, a, b, device);
TIME(greater, a, b, device);
TIME(greater_equal, a, b, device);
TIME(less, a, b, device);
TIME(less_equal, a, b, device);
}
void time_matvec() {
int M = 2000, N = 200;
auto a = random::uniform({M, N});
auto b = random::uniform({N});
auto c = random::uniform({M});
eval(a, b, c);
auto matvec = [&]() { return matmul(a, b); };
TIME(matvec);
auto matvec_transpose = [&]() { return matmul(transpose(a), c); };
TIME(matvec_transpose);
}
void time_matmul() {
int M = 1000, N = 1000, K = 1000;
auto a = random::uniform({M, K});
auto b = random::uniform({K, N});
auto device = default_device();
eval(a, b);
TIME(matmul, a, b, device);
auto transpose_matmul = [&]() { return matmul(transpose(a), b); };
TIME(transpose_matmul);
}
void time_reductions() {
auto a = random::normal({10000, 1000});
eval(a);
auto sum_all = [&a]() { return sum(a, false); };
TIME(sum_all);
auto sum_along_0 = [&a]() { return sum(a, 0, false); };
TIME(sum_along_0);
auto sum_along_1 = [&a]() { return sum(a, 1, false); };
TIME(sum_along_1);
auto prod_all = [&a]() { return prod(a, false); };
TIME(prod_all);
auto all_true = [&a]() { return all(a, false); };
TIME(all_true);
auto all_along_0 = [&a]() { return all(a, 0, false); };
TIME(all_along_0);
auto all_along_1 = [&a]() { return all(a, 1, false); };
TIME(all_along_1);
auto any_true = [&a]() { return any(a, false); };
TIME(any_true);
auto argmin_along_0 = [&a]() { return argmin(a, 0, false); };
TIME(argmin_along_0);
auto argmin_along_1 = [&a]() { return argmin(a, 1, false); };
TIME(argmin_along_1);
}
void time_gather_scatter() {
auto a = random::normal({1000, 768});
eval(a);
auto indices = random::randint(0, 1000, {256});
eval(indices);
auto embedding_lookup = [&a, &indices]() { return take(a, indices, 0); };
TIME(embedding_lookup);
indices = random::randint(0, 768 * 1000, {256 * 768});
eval(indices);
auto single_element_lookup = [&a, &indices]() { return take(a, indices); };
TIME(single_element_lookup);
indices = random::randint(0, 1000, {256});
auto updates = random::normal({256, 1, 768});
eval(indices, updates);
auto embedding_update = [&a, &indices, &updates]() {
return scatter(a, indices, updates, 0);
};
TIME(embedding_update);
auto embedding_add = [&a, &indices, &updates]() {
return scatter_add(a, indices, updates, 0);
};
TIME(embedding_add);
a = reshape(a, {-1});
indices = random::randint(0, 768 * 1000, {768 * 256});
updates = random::normal({256 * 768, 1});
eval(a, indices, updates);
auto single_element_update = [&a, &indices, &updates]() {
return scatter(a, indices, updates, 0);
};
TIME(single_element_update);
auto single_element_add = [&a, &indices, &updates]() {
return scatter_add(a, indices, updates, 0);
};
TIME(single_element_add);
}
int main() {
std::cout << "Benchmarks for " << default_device() << std::endl;
time_creation_ops();
time_type_conversions();
time_unary_ops();
time_binary_ops();
time_strided_ops();
time_random_generation();
time_comparisons();
time_matvec();
time_matmul();
time_reductions();
time_gather_scatter();
}

View File

@@ -0,0 +1,15 @@
Microbenchmarks comparing MLX to PyTorch
========================================
Implement the same microbenchmarks in MLX and PyTorch to compare and make a
list of the biggest possible performance improvements and/or regressions.
Run with `python bench_mlx.py sum_axis --size 8x1024x128 --axis 2 --cpu` for
instance to measure the times it takes to sum across the 3rd axis of the above
tensor on the cpu.
`compare.py` runs several benchmarks and compares the speed-up or lack thereof
in comparison to PyTorch.
Each bench script can be run with `--print-pid` to print the PID and wait for a
key in order to ease attaching a debugger.

View File

@@ -0,0 +1,313 @@
import argparse
import math
import os
import time
import mlx.core as mx
def int_or_list(x):
try:
return int(x)
except ValueError:
return [int(xi) for xi in x.split(",")]
def none_or_list(x):
if x == "":
return None
else:
return [int(xi) for xi in x.split(",")]
def bench(f, *args):
for i in range(10):
f(*args)
s = time.time()
for i in range(100):
f(*args)
e = time.time()
return e - s
def matmul_square(x):
y = x
for i in range(10):
y = y @ x
mx.eval(y)
return y
def matmul(x, y):
ys = []
for i in range(10):
ys.append(x @ y)
mx.eval(ys)
def conv1d(x, y):
ys = []
for i in range(10):
ys.append(mx.conv1d(x, y))
mx.eval(ys)
def conv2d(x, y):
ys = []
for i in range(10):
ys.append(mx.conv2d(x, y))
mx.eval(ys)
def binary(op, x, y):
for i in range(100):
y = getattr(mx, op)(x, y)
mx.eval(y)
def reduction(op, axis, x):
ys = []
for i in range(100):
ys.append(getattr(mx, op)(x, axis=axis))
mx.eval(ys)
def softmax(axis, x):
ys = []
for i in range(100):
ex = mx.exp(x - mx.max(x, axis=axis, keepdims=True))
y = ex / mx.sum(ex, axis=axis, keepdims=True)
ys.append(y)
mx.eval(ys)
def softmax_fused(axis, x):
ys = []
for i in range(100):
y = mx.softmax(x, axis=axis)
ys.append(y)
mx.eval(ys)
def relu(x):
y = x
for i in range(100):
y = mx.maximum(y, 0)
mx.eval(y)
def scalar_mult(x):
y = x
for i in range(100):
y = y * (1.0 / (1 + i))
mx.eval(y)
def cross_entropy(targets, x):
ys = []
for i in range(100):
y = mx.logsumexp(x, axis=-1, keepdims=True) - mx.take_along_axis(
x, mx.reshape(targets, (-1, 1)), axis=-1
)
ys.append(mx.mean(y))
mx.eval(ys)
def logsumexp(axis, x):
ys = []
for i in range(100):
ys.append(mx.logsumexp(x, axis=axis))
mx.eval(ys)
def linear(w, b, x):
ys = []
for i in range(10):
ys.append(x @ mx.transpose(w, (1, 0)) + b)
mx.eval(ys)
def rope(x):
*_, N, D = x.shape
ys = []
for i in range(10):
shape = x.shape
x = mx.reshape(x, (-1, N, D))
positions = mx.arange(N)
freqs = mx.exp(mx.arange(0.0, D // 2) / math.log(10000 / (D // 2 - 1)))
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
costheta = mx.cos(theta)
sintheta = mx.sin(theta)
x1 = x[..., ::2]
x2 = x[..., 1::2]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
y = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
y = mx.reshape(y, (-1, N, D))
ys.append(y)
mx.eval(ys)
def concatenate(axis, x, y):
ys = []
for i in range(10):
ys.append(mx.concatenate([x, y], axis=axis))
mx.eval(ys)
def cumsum(axis, x):
ys = []
for i in range(10):
ys.append(mx.cumsum(x, axis))
mx.eval(ys)
def sort(axis, x):
ys = []
for i in range(10):
ys.append(mx.sort(x, axis))
mx.eval(ys)
def topk(axis, x):
k = x.shape[axis] // 3
ys = []
for i in range(10):
ys.append(mx.topk(x, k, axis))
mx.eval(ys)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("benchmark", help="Choose the benchmark to run")
parser.add_argument(
"--size",
default=[(1024, 1024)],
type=lambda x: list(map(int, x.split("x"))),
help="Set the matrix size",
action="append",
)
parser.add_argument(
"--axis",
default=[1],
type=int_or_list,
help="Set a reduction axis",
action="append",
)
parser.add_argument(
"--transpose",
type=none_or_list,
default=[],
help="Permute the matrix",
action="append",
)
parser.add_argument(
"--print-pid", action="store_true", help="Print the PID and pause"
)
parser.add_argument("--cpu", action="store_true", help="Use the CPU")
parser.add_argument(
"--fused", action="store_true", help="Use fused functions where possible"
)
parser.add_argument(
"--dtype", choices=["float32", "float16", "bfloat16"], default="float32"
)
args = parser.parse_args()
if len(args.size) > 1:
args.size.pop(0)
if len(args.axis) > 1:
args.axis.pop(0)
if args.print_pid:
print(os.getpid())
input("Press enter to run")
if args.cpu:
mx.set_default_device(mx.cpu)
else:
mx.set_default_device(mx.gpu)
dtype = dict(float32=mx.float32, float16=mx.float16, bfloat16=mx.bfloat16)[
args.dtype
]
xs = []
for size in args.size:
xs.append(mx.random.normal(size).astype(dtype))
for i, t in enumerate(args.transpose):
if t is None:
continue
xs[i] = mx.transpose(xs[i], t)
mx.eval(xs)
x = xs[0]
axis = args.axis[0]
if args.benchmark == "matmul_square":
print(bench(matmul_square, x))
elif args.benchmark == "matmul":
print(bench(matmul, *xs))
elif args.benchmark == "linear":
print(bench(linear, *xs))
elif args.benchmark == "sum_axis":
print(bench(reduction, "sum", axis, x))
elif args.benchmark == "sum_all":
print(bench(reduction, "sum", None, x))
elif args.benchmark == "argmax":
print(bench(reduction, "argmax", axis, x))
elif args.benchmark == "add":
print(bench(binary, "add", *xs))
elif args.benchmark == "mul":
print(bench(binary, "multiply", *xs))
elif args.benchmark == "softmax":
if args.fused:
print(bench(softmax_fused, axis, x))
else:
print(bench(softmax, axis, x))
elif args.benchmark == "relu":
print(bench(relu, x))
elif args.benchmark == "scalar_mul":
print(bench(scalar_mult, x))
elif args.benchmark == "cross_entropy":
if len(size) != 2:
raise ValueError("Error: [cross_entropy] benchmark requires a 2 dim size")
targets = mx.zeros((len(x),), dtype=mx.uint32)
print(bench(cross_entropy, targets, x))
elif args.benchmark == "logsumexp":
print(bench(logsumexp, axis, x))
elif args.benchmark == "rope":
print(bench(rope, x))
elif args.benchmark == "concatenate":
print(bench(concatenate, axis, *xs))
elif args.benchmark == "cumsum":
print(bench(cumsum, axis, *xs))
elif args.benchmark == "conv1d":
print(bench(conv1d, *xs))
elif args.benchmark == "conv2d":
print(bench(conv2d, *xs))
elif args.benchmark == "sort":
print(bench(sort, axis, x))
elif args.benchmark == "topk":
print(bench(topk, axis, x))
else:
raise ValueError("Unknown benchmark")

View File

@@ -0,0 +1,338 @@
import argparse
import os
import time
import torch
import torch.mps
def int_or_list(x):
try:
return int(x)
except ValueError:
return [int(xi) for xi in x.split(",")]
def none_or_list(x):
if x == "":
return None
else:
return [int(xi) for xi in x.split(",")]
def bench(f, *args):
for i in range(10):
f(*args)
s = time.time()
for i in range(100):
f(*args)
e = time.time()
return e - s
def sync_if_needed(x):
if x.device != torch.device("cpu"):
torch.mps.synchronize()
@torch.no_grad()
def matmul_square(x):
y = x
for i in range(10):
y = y @ x
sync_if_needed(x)
@torch.no_grad()
def matmul(x, y):
ys = []
for i in range(10):
ys.append(x @ y)
sync_if_needed(x)
@torch.no_grad()
def conv1d(x, y):
x = torch.transpose(x, -1, -2)
y = torch.transpose(y, -1, -2)
ys = []
for i in range(10):
ys.append(torch.nn.functional.conv1d(x, y))
sync_if_needed(x)
@torch.no_grad()
def conv2d(x, y):
x = torch.permute(x, (0, 3, 1, 2))
y = torch.permute(y, (0, 3, 1, 2))
ys = []
for i in range(10):
ys.append(torch.nn.functional.conv2d(x, y))
sync_if_needed(x)
@torch.no_grad()
def binary(op, x, y):
for i in range(100):
y = getattr(torch, op)(x, y)
sync_if_needed(x)
@torch.no_grad()
def reduction(op, axis, x):
ys = []
for i in range(100):
ys.append(getattr(x, op)(axis))
sync_if_needed(x)
@torch.no_grad()
def softmax(axis, x):
ys = []
for i in range(100):
ex = torch.exp(x - torch.max(x, dim=axis, keepdims=True).values)
y = ex / torch.sum(ex, dim=axis, keepdims=True)
ys.append(y)
sync_if_needed(x)
@torch.no_grad()
def softmax_fused(axis, x):
ys = []
for i in range(100):
ys.append(torch.nn.functional.softmax(x, dim=axis))
sync_if_needed(x)
@torch.no_grad()
def relu(x):
y = x
for i in range(100):
y = torch.nn.functional.relu(y)
sync_if_needed(x)
@torch.no_grad()
def scalar_mult(x):
y = x
for i in range(100):
y = y * (1.0 / (1 + i))
sync_if_needed(x)
@torch.no_grad()
def cross_entropy(targets, x):
ys = []
for i in range(100):
ys.append(torch.nn.functional.cross_entropy(x, targets))
sync_if_needed(x)
@torch.no_grad()
def logsumexp(axis, x):
ys = []
for i in range(100):
ys.append(torch.logsumexp(x, dim=axis))
sync_if_needed(x)
@torch.no_grad()
def linear_fused(w, b, x):
ys = []
for i in range(10):
ys.append(torch.nn.functional.linear(x, w, b))
sync_if_needed(x)
@torch.no_grad()
def linear(w, b, x):
ys = []
for i in range(10):
ys.append((x @ torch.transpose(w, -2, -1)) + b)
sync_if_needed(x)
@torch.no_grad()
def rope(x):
*_, N, D = x.shape
ys = []
for i in range(10):
x = x.view(-1, N, D)
positions = torch.arange(N, device=x.device)
freqs = 10000 ** torch.linspace(0, 1, D // 2, device=x.device)
theta = positions[:, None] * freqs[None]
costheta = torch.cos(theta)
sintheta = torch.sin(theta)
x1 = x[..., ::2]
x2 = x[..., 1::2]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
y = torch.cat([rx1[..., None], rx2[..., None]], dim=-1)
y = y.reshape(-1, N, D)
ys.append(y)
sync_if_needed(x)
@torch.no_grad()
def concatenate(axis, x, y):
ys = []
for i in range(10):
ys.append(torch.cat([x, y], dim=axis))
sync_if_needed(x)
@torch.no_grad()
def cumsum(axis, x):
ys = []
for i in range(10):
ys.append(x.cumsum(axis))
sync_if_needed(x)
@torch.no_grad()
def sort(axis, x):
ys = []
for i in range(10):
ys.append(torch.sort(x, dim=axis)[0])
sync_if_needed(x)
@torch.no_grad()
def topk(axis, x):
k = x.shape[axis] // 3
ys = []
for i in range(10):
ys.append(torch.topk(x, k, dim=axis)[0])
sync_if_needed(x)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("benchmark", help="Choose the benchmark to run")
parser.add_argument(
"--size",
default=[(1024, 1024)],
type=lambda x: list(map(int, x.split("x"))),
help="Set the matrix size",
action="append",
)
parser.add_argument(
"--axis",
default=[1],
type=int_or_list,
help="Set a reduction axis",
action="append",
)
parser.add_argument(
"--transpose",
type=none_or_list,
default=[],
help="Permute the matrix",
action="append",
)
parser.add_argument(
"--print-pid", action="store_true", help="Print the PID and pause"
)
parser.add_argument("--cpu", action="store_true", help="Use the CPU")
parser.add_argument(
"--fused", action="store_true", help="Use fused functions where possible"
)
parser.add_argument("--dtype", choices=["float32", "float16"], default="float32")
args = parser.parse_args()
if len(args.size) > 1:
args.size.pop(0)
if len(args.axis) > 1:
args.axis.pop(0)
if args.print_pid:
print(os.getpid())
input("Press enter to run")
torch.set_num_threads(1)
device = "cpu" if args.cpu else "mps"
dtype = dict(float32=torch.float32, float16=torch.float16)[args.dtype]
xs = []
for size in args.size:
xs.append(torch.randn(*size).to(device).to(dtype))
for i, t in enumerate(args.transpose):
if t is None:
continue
xs[i] = xs[i].permute(*t)
x = xs[0]
axis = args.axis[0]
if args.benchmark == "matmul_square":
print(bench(matmul_square, x))
elif args.benchmark == "matmul":
print(bench(matmul, *xs))
elif args.benchmark == "linear":
if args.fused:
print(bench(linear_fused, *xs))
else:
print(bench(linear, *xs))
elif args.benchmark == "sum_axis":
print(bench(reduction, "sum", axis, x))
elif args.benchmark == "sum_all":
print(bench(reduction, "sum", None, x))
elif args.benchmark == "argmax":
print(bench(reduction, "argmax", axis, x))
elif args.benchmark == "add":
print(bench(binary, "add", *xs))
elif args.benchmark == "mul":
print(bench(binary, "mul", *xs))
elif args.benchmark == "softmax":
if args.fused:
print(bench(softmax_fused, axis, x))
else:
print(bench(softmax, axis, x))
elif args.benchmark == "relu":
print(bench(relu, x))
elif args.benchmark == "scalar_mul":
print(bench(scalar_mult, x))
elif args.benchmark == "cross_entropy":
if len(size) != 2:
raise ValueError("Error: [cross_entropy] benchmark requires a 2 dim size")
targets = torch.zeros(len(x), dtype=torch.long).to(x.device)
print(bench(cross_entropy, targets, x))
elif args.benchmark == "logsumexp":
print(bench(logsumexp, axis, x))
elif args.benchmark == "rope":
print(bench(rope, x))
elif args.benchmark == "concatenate":
print(bench(concatenate, axis, *xs))
elif args.benchmark == "cumsum":
print(bench(cumsum, axis, *xs))
elif args.benchmark == "conv1d":
print(bench(conv1d, *xs))
elif args.benchmark == "conv2d":
print(bench(conv2d, *xs))
elif args.benchmark == "sort":
print(bench(sort, axis, x))
elif args.benchmark == "topk":
print(bench(topk, axis, x))
else:
raise ValueError("Unknown benchmark")

View File

@@ -0,0 +1,253 @@
#!/usr/bin/env python
import argparse
import re
from pathlib import Path
from subprocess import run
BENCH_MLX = Path(__file__).parent / "bench_mlx.py"
BENCH_TORCH = Path(__file__).parent / "bench_torch.py"
def run_or_raise(*args, **kwargs):
try:
result = run(*args, capture_output=True, **kwargs)
return float(result.stdout)
except ValueError:
raise ValueError(f"stdout: {result.stdout}\nstderr: {result.stderr}")
def compare(args):
t_mlx = run_or_raise(["python", BENCH_MLX] + args)
t_torch = run_or_raise(["python", BENCH_TORCH] + args)
print((t_torch - t_mlx) / t_torch, " ".join(args), sep="\t")
def compare_mlx_dtypes(args, dt1, dt2):
t_mlx_dt1 = run_or_raise(["python", BENCH_MLX] + args + ["--dtype", dt1])
t_mlx_dt2 = run_or_raise(["python", BENCH_MLX] + args + ["--dtype", dt2])
print((t_mlx_dt2 - t_mlx_dt1) / t_mlx_dt2, " ".join(args), sep="\t")
def make_regex_search(regexes):
compiled_regexes = list(map(re.compile, regexes))
def search(x):
return (c.search(x) is not None for c in compiled_regexes)
return search
def make_predicate(positive_filter, negative_filter):
if positive_filter is not None:
positive_filter_search = make_regex_search(positive_filter)
positive_filter = lambda x: all(positive_filter_search(x))
else:
positive_filter = lambda x: True
if negative_filter is not None:
negative_filter_search = make_regex_search(negative_filter)
negative_filter = lambda x: not any(negative_filter_search(x))
else:
negative_filter = lambda x: True
def predicate(x):
return positive_filter(x) and negative_filter(x)
return predicate
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run comparisons agains PyTorch")
parser.add_argument(
"--filter", "-f", help="Regex filter to select benchmarks", nargs="+"
)
parser.add_argument(
"--negative_filter", "-n", help="Regex filter to remove benchmarks", nargs="+"
)
parser.add_argument(
"--mlx_dtypes",
"-d",
help="Compare mlx benchmarks between the 2 provided data types",
nargs=2,
)
args, rest = parser.parse_known_args()
_filter = make_predicate(args.filter, args.negative_filter)
if args.mlx_dtypes:
compare_filtered = (
lambda x: compare_mlx_dtypes(
x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]
)
if _filter(x)
else None
)
else:
compare_filtered = lambda x: compare(x.split() + rest) if _filter(x) else None
# Binary ops
compare_filtered("add --size 10x1024x128 --size 1x1024x128 --cpu")
compare_filtered("add --size 10x1024x128 --size 1x1024x128")
compare_filtered("add --size 1024x128 --size 1x128 --cpu")
compare_filtered("add --size 1024x128 --size 1x128")
compare_filtered("add --size 1024x4096 --size 1x4096 --cpu")
compare_filtered("add --size 1024x4096 --size 1x4096")
compare_filtered("add --size 1024x4096 --size 1x1024 --transpose 1,0 --cpu")
compare_filtered("add --size 1024x4096 --size 1x1024 --transpose 1,0")
compare_filtered("add --size 1024x1024 --size 1024x1024 --cpu")
compare_filtered("add --size 1024x1024 --size 1024x1024")
compare_filtered("add --size 1024x1024 --size 1024x1024 --transpose 1,0 --cpu")
compare_filtered("add --size 1024x1024 --size 1024x1024 --transpose 1,0")
compare_filtered(
"add --size 1024x1024 --size 1024x1024 --transpose 1,0 --transpose 1,0 --cpu"
)
compare_filtered(
"add --size 1024x1024 --size 1024x1024 --transpose 1,0 --transpose 1,0"
)
# Reduction ops
compare_filtered("sum_all --size 10x1024x128 --cpu")
compare_filtered("sum_all --size 10x1024x128")
compare_filtered("sum_axis --size 16x1024x128 --axis 2 --cpu")
compare_filtered("sum_axis --size 16x1024x128 --axis 2")
compare_filtered("sum_axis --size 16x128x1024 --axis 2 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 2")
compare_filtered("sum_axis --size 1024x1024 --axis 1 --cpu")
compare_filtered("sum_axis --size 1024x1024 --axis 1")
compare_filtered("sum_axis --size 1024x1024 --axis 0 --cpu")
compare_filtered("sum_axis --size 1024x1024 --axis 0")
compare_filtered("sum_axis --size 16x128x1024 --axis 1 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 1")
compare_filtered("sum_axis --size 16x128x1024 --axis 0 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0")
compare_filtered("argmax --size 10x1024x128 --axis 1 --cpu")
compare_filtered("argmax --size 10x1024x128 --axis 1")
compare_filtered("argmax --size 10x1024x128 --axis 2 --cpu")
compare_filtered("argmax --size 10x1024x128 --axis 2")
compare_filtered("argmax --size 1024x1024 --axis 1 --cpu")
compare_filtered("argmax --size 1024x1024 --axis 1")
# Matmul ops
compare_filtered("matmul_square --size 1024x1024")
compare_filtered("matmul_square --size 1024x1024 --cpu")
compare_filtered("matmul_square --size 16x1024x1024")
compare_filtered("matmul_square --size 16x1024x1024 --cpu")
compare_filtered(
"matmul --size 16x768x768 --size 16x768x768 --transpose= --transpose 0,2,1"
)
compare_filtered(
"matmul --size 16x768x768 --size 16x768x768 --transpose= --transpose 0,2,1 --cpu"
)
compare_filtered(
"matmul --size 16x768x128 --size 16x768x128 --transpose= --transpose 0,2,1"
)
compare_filtered(
"matmul --size 16x768x128 --size 16x768x128 --transpose= --transpose 0,2,1 --cpu"
)
compare_filtered("matmul --size 512x8192 --size 8192x512")
compare_filtered("matmul --size 512x8192 --size 8192x512 --cpu")
# compare_filtered("matmul --size 512x131072 --size 131072x512")
# compare_filtered("matmul --size 512x131072 --size 131072x512 --cpu")
compare_filtered("matmul --size 8192x512 --size 512x8192")
compare_filtered("matmul --size 8192x512 --size 512x8192 --cpu")
# compare_filtered("matmul --size 131072x512 --size 512x512")
# compare_filtered("matmul --size 131072x512 --size 512x512 --cpu")
compare_filtered("linear --size 1024x1024 --size 1024 --size 128x1024")
compare_filtered("linear --size 1024x1024 --size 1024 --size 128x1024 --cpu")
compare_filtered("linear --size 1024x1024 --size 1024 --size 128x1024 --fused")
compare_filtered(
"linear --size 1024x1024 --size 1024 --size 128x1024 --fused --cpu"
)
# Matvec ops
compare_filtered("matmul --size 1x1x4096 --size 4096x4096 --cpu")
compare_filtered("matmul --size 1x1x4096 --size 4096x4096")
compare_filtered(
"matmul --size 1x1x4096 --size 4096x4096 --transpose= --transpose 1,0 --cpu"
)
compare_filtered(
"matmul --size 1x1x4096 --size 4096x4096 --transpose= --transpose 1,0"
)
compare_filtered("matmul --size 32x1x1000 --size 32x1000x128 --cpu")
compare_filtered("matmul --size 32x1x1000 --size 32x1000x128")
compare_filtered(
"matmul --size 32x1x1000 --size 32x128x1000 --transpose= --transpose 0,2,1 --cpu"
)
compare_filtered(
"matmul --size 32x1x1000 --size 32x128x1000 --transpose= --transpose 0,2,1"
)
# Various ops
compare_filtered("softmax --size 32x16x1024 --axis 2")
compare_filtered("softmax --size 32x16x1024 --axis 2 --cpu")
compare_filtered("softmax --size 32x16x1024 --axis 2 --fused")
compare_filtered("softmax --size 32x16x1024 --axis 2 --fused --cpu")
compare_filtered("softmax --size 2x1024x1024 --axis 1")
compare_filtered("softmax --size 2x1024x1024 --axis 1 --cpu")
compare_filtered("softmax --size 2x1024x1024 --axis 1 --fused")
compare_filtered("softmax --size 2x1024x1024 --axis 1 --fused --cpu")
compare_filtered("relu --size 32x16x1024")
compare_filtered("relu --size 32x16x1024 --cpu")
compare_filtered("scalar_mul --size 32x16x1024")
compare_filtered("scalar_mul --size 32x16x1024 --cpu")
compare_filtered("cross_entropy --size 256x1024")
compare_filtered("cross_entropy --size 256x1024 --cpu")
compare_filtered("logsumexp --size 1024x1024 --axis 1")
compare_filtered("logsumexp --size 1024x1024 --axis 1 --cpu")
compare_filtered("logsumexp --size 1024x1024 --axis 0")
compare_filtered("logsumexp --size 1024x1024 --axis 0 --cpu")
compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 2")
compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 2 --cpu")
compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 1")
compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 1 --cpu")
compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 0")
compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 0 --cpu")
compare_filtered("concatenate --size 32x1024x128 --size 32x16x128 --axis 1")
compare_filtered("concatenate --size 32x1024x128 --size 32x16x128 --axis 1 --cpu")
compare_filtered("concatenate --size 32x1024x128 --size 32x1x128 --axis 1")
compare_filtered("concatenate --size 32x1024x128 --size 32x1x128 --axis 1 --cpu")
compare_filtered("concatenate --size 1x32x1024x128 --size 1x32x1x128 --axis 2")
compare_filtered(
"concatenate --size 1x32x1024x128 --size 1x32x1x128 --axis 2 --cpu"
)
compare_filtered("conv1d --size 1x1000x80 --size 128x11x80")
compare_filtered("conv1d --size 1x1000x80 --size 128x11x80 --cpu")
compare_filtered("conv1d --size 16x1000x80 --size 128x11x80")
compare_filtered("conv1d --size 4x1000x80 --size 128x11x80 --cpu")
compare_filtered("conv2d --size 1x256x256x3 --size 8x3x3x3")
compare_filtered("conv2d --size 1x256x256x3 --size 8x3x3x3 --cpu")
compare_filtered("conv2d --size 16x256x256x3 --size 8x3x3x3")
compare_filtered("conv2d --size 4x256x256x3 --size 8x3x3x3 --cpu")
compare_filtered("cumsum --size 1024x1024 --axis 1 --cpu")
compare_filtered("cumsum --size 1024x1024 --axis 0 --cpu")
compare_filtered("cumsum --size 1024x1024 --axis 1")
compare_filtered("cumsum --size 1024x1024 --axis 0")
compare_filtered("cumsum --size 128x1024 --axis 1")
compare_filtered("cumsum --size 128x1024 --axis 0")
compare_filtered("cumsum --size 1024x4096 --axis 1")
compare_filtered("cumsum --size 1024x4096 --axis 0")
compare_filtered("cumsum --size 128x4096 --axis 1")
compare_filtered("cumsum --size 128x4096 --axis 0")
compare_filtered("cumsum --size 1024x7777 --axis 1")
compare_filtered("cumsum --size 1024x7777 --axis 0")
compare_filtered("cumsum --size 128x7777 --axis 1")
compare_filtered("cumsum --size 128x7777 --axis 0")
compare_filtered("cumsum --size 32768x128 --axis 1")
compare_filtered("cumsum --size 32768x128 --axis 0")
compare_filtered("sort --size 1024x1024 --axis 0")
compare_filtered("sort --size 1024x1024 --axis 1")
compare_filtered("sort --size 32768x128 --axis 0")
compare_filtered("sort --size 32768x128 --axis 1")
compare_filtered("sort --size 128x128 --axis 0 --cpu")
compare_filtered("sort --size 128x128 --axis 1 --cpu")
compare_filtered("topk --size 1024x1024 --axis 0")
compare_filtered("topk --size 1024x1024 --axis 1")
compare_filtered("topk --size 32768x128 --axis 0")
compare_filtered("topk --size 32768x128 --axis 1")
compare_filtered("topk --size 128x128 --axis 0 --cpu")
compare_filtered("topk --size 128x128 --axis 1 --cpu")

View File

@@ -0,0 +1,196 @@
import math
import time
import jax
import jax.numpy as jnp
from flax import linen as nn
class RoPE(nn.Module):
dims: int
traditional: bool = False
def _compute_rope(self, costheta, sintheta, x):
x1 = x[..., : self.dims // 2]
x2 = x[..., self.dims // 2 : self.dims]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
rx = jnp.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
else:
rx = jnp.concatenate([rx1, rx2], axis=-1)
return rx
def _compute_traditional_rope(self, costheta, sintheta, x):
x1 = x[..., ::2]
x2 = x[..., 1::2]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
raise NotImplementedError(
"RoPE doesn't implement partial traditional application"
)
rx = jnp.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
return rx
@staticmethod
def create_cos_sin_theta(
N: int,
D: int,
offset: int = 0,
base: float = 10000,
dtype=jnp.float32,
):
D = D // 2
positions = jnp.arange(offset, N, dtype=dtype)
freqs = jnp.exp(-jnp.arange(0, D, dtype=dtype) * (math.log(base) / D))
theta = positions.reshape((-1, 1)) * freqs.reshape((1, -1))
costheta = jnp.cos(theta)
sintheta = jnp.sin(theta)
return costheta, sintheta
@nn.compact
def __call__(self, x, offset: int = 0):
shape = x.shape
x = x.reshape((-1, shape[-2], shape[-1]))
N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, dtype=x.dtype
)
rope = (
self._compute_traditional_rope if self.traditional else self._compute_rope
)
rx = rope(costheta, sintheta, x)
return rx.reshape(shape)
class LlamaAttention(nn.Module):
dims: int
num_heads: int
dtype: jnp.dtype
def setup(self):
num_heads = self.num_heads
dims = self.dims
self.rope = RoPE(dims // num_heads, True)
self.query_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
self.key_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
self.value_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
self.out_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
def __call__(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
queries = queries.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
keys = keys.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
values = values.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = jnp.concatenate([key_cache, keys], axis=2)
values = jnp.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.transpose((0, 1, 3, 2))
if mask is not None:
scores = scores + mask
scores = jax.nn.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose((0, 2, 1, 3)).reshape((B, L, -1))
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
dims: int
mlp_dims: int
num_heads: int
dtype: jnp.dtype
def setup(self):
dims = self.dims
mlp_dims = self.mlp_dims
num_heads = self.num_heads
self.attention = LlamaAttention(dims, num_heads, dtype)
self.norm1 = nn.RMSNorm(param_dtype=self.dtype)
self.norm2 = nn.RMSNorm(param_dtype=self.dtype)
self.linear1 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
self.linear2 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
self.linear3 = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
def __call__(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = jax.nn.silu(a) * b
y = self.linear3(y)
x = x + y
return x, cache
def measure(model, x, cache):
for i in range(5):
y, c = model(x, mask=None, cache=cache)
jax.block_until_ready((y, c))
start = time.time()
for i in range(5):
y, c = model(x, mask=None, cache=cache)
jax.block_until_ready((y, c))
end = time.time()
return (end - start) * 1000 / 5
if __name__ == "__main__":
H = 32
D = 4096
F = 43 * 256
C = 1000
dtype = jnp.float16
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k1, (1, 1, D), dtype)
cache = [
jax.random.normal(k2, [1, H, C, D // H], dtype),
jax.random.normal(k3, [1, H, C, D // H], dtype),
]
layer = LlamaEncoderLayer(D, F, H, dtype=dtype)
params = layer.init(k4, x, mask=None, cache=cache)["params"]
@jax.jit
def model_fn(x, mask, cache):
return layer.apply({"params": params}, x, mask=mask, cache=cache)
T = measure(model_fn, x, cache)
print("Time per layer per token:", T, "ms")
print("Lower bound total time per token:", T * 32, "ms")

View File

@@ -0,0 +1,197 @@
import math
import time
import torch
import torch.nn as nn
import torch.mps
def sync_if_needed(x):
if x.device != torch.device("cpu"):
torch.mps.synchronize()
class RoPE(nn.Module):
def __init__(self, dims: int, traditional: bool = False):
super().__init__()
self.dims = dims
self.traditional = traditional
def _compute_rope(self, costheta, sintheta, x):
x1 = x[..., : self.dims // 2]
x2 = x[..., self.dims // 2 : self.dims]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
rx = torch.cat([rx1, rx2, x[..., self.dims :]], dim=-1)
else:
rx = torch.cat([rx1, rx2], dim=-1)
return rx
def _compute_traditional_rope(self, costheta, sintheta, x):
x1 = x[..., ::2]
x2 = x[..., 1::2]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
raise NotImplementedError(
"RoPE doesn't implement partial traditional application"
)
rx = torch.cat([rx1[..., None], rx2[..., None]], dim=-1)
return rx
def forward(self, x, offset: int = 0):
shape = x.shape
x = x.view(-1, shape[-2], shape[-1])
N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, device=x.device, dtype=x.dtype
)
rope = (
self._compute_traditional_rope if self.traditional else self._compute_rope
)
rx = rope(costheta, sintheta, x)
return rx.view(*shape)
@staticmethod
def create_cos_sin_theta(
N: int,
D: int,
offset: int = 0,
base: float = 10000,
device="cpu",
dtype=torch.float32,
):
D = D // 2
positions = torch.arange(offset, N, dtype=dtype, device=device)
freqs = torch.exp(
-torch.arange(0, D, dtype=dtype, device=device) * (math.log(base) / D)
)
theta = positions.view(-1, 1) * freqs.view(1, -1)
costheta = torch.cos(theta)
sintheta = torch.sin(theta)
return costheta, sintheta
class RMSNorm(nn.Module):
def __init__(self, dims: int, epsilon: float = 1e-6):
super().__init__()
self.gamma = nn.Parameter(torch.ones((dims,)))
self.epsilon = epsilon
def forward(self, x):
n = torch.rsqrt(x.square().mean(dim=-1, keepdims=True) + self.epsilon)
return self.gamma * x * n
class LlamaAttention(nn.Module):
def __init__(self, dims: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.rope = RoPE(dims // num_heads, True)
self.query_proj = nn.Linear(dims, dims, bias=False)
self.key_proj = nn.Linear(dims, dims, bias=False)
self.value_proj = nn.Linear(dims, dims, bias=False)
self.out_proj = nn.Linear(dims, dims, bias=False)
def forward(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
queries = queries.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
keys = keys.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
values = values.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = torch.cat([key_cache, keys], dim=2)
values = torch.cat([value_cache, values], dim=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.permute(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = torch.softmax(scores, dim=-1)
values_hat = (scores @ values).permute(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
super().__init__()
self.attention = LlamaAttention(dims, num_heads)
self.norm1 = RMSNorm(dims)
self.norm2 = RMSNorm(dims)
self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
def forward(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = torch.nn.functional.silu(a) * b
y = self.linear3(y)
x = x + y
return x, cache
@torch.no_grad()
def measure(model, x, cache):
for i in range(5):
y, c = model(x, mask=None, cache=cache)
sync_if_needed(x)
start = time.time()
for i in range(5):
y, c = model(x, mask=None, cache=cache)
sync_if_needed(x)
end = time.time()
return (end - start) * 1000 / 5
if __name__ == "__main__":
H = 32
D = 4096
F = 43 * 256
C = 1000
device = torch.device("mps")
dtype = torch.float16
layer = LlamaEncoderLayer(D, F, H).to(device).to(dtype)
x = torch.randn(1, 1, D).to(device).to(dtype)
cache = [
torch.randn(1, H, C, D // H).to(device).to(dtype),
torch.randn(1, H, C, D // H).to(device).to(dtype),
]
T = measure(layer, x, cache)
print("Time per layer per token:", T, "ms")
print("Lower bound total time per token:", T * 32, "ms")

View File

@@ -0,0 +1,106 @@
import argparse
import mlx.core as mx
from time_utils import time_fn
def time_add():
a = mx.random.uniform(shape=(32, 1024, 1024))
b = mx.random.uniform(shape=(32, 1024, 1024))
mx.eval(a, b)
time_fn(mx.add, a, b)
aT = mx.transpose(a, [0, 2, 1])
mx.eval(aT)
def transpose_add(a, b):
return mx.add(a, b)
time_fn(transpose_add, aT, b)
b = mx.random.uniform(shape=(1024,))
mx.eval(b)
def slice_add(a, b):
return mx.add(a, b)
time_fn(slice_add, a, b)
b = mx.reshape(b, (1, 1024, 1))
mx.eval(b)
def mid_slice_add(a, b):
return mx.add(a, b)
time_fn(mid_slice_add, a, b)
def time_matmul():
a = mx.random.uniform(shape=(1024, 1024))
b = mx.random.uniform(shape=(1024, 1024))
mx.eval(a, b)
time_fn(mx.matmul, a, b)
def time_negative():
a = mx.random.uniform(shape=(10000, 1000))
mx.eval(a)
def negative(a):
return -a
mx.eval(a)
time_fn(negative, a)
def time_exp():
a = mx.random.uniform(shape=(1000, 100))
mx.eval(a)
time_fn(mx.exp, a)
def time_logsumexp():
a = mx.random.uniform(shape=(64, 10, 10000))
mx.eval(a)
time_fn(mx.logsumexp, a, axis=-1)
def time_take():
a = mx.random.uniform(shape=(10000, 500))
ids = mx.random.randint(low=0, high=10000, shape=(20, 10))
ids = [mx.reshape(idx, (-1,)) for idx in ids]
mx.eval(ids)
def random_take():
return [mx.take(a, idx, 0) for idx in ids]
time_fn(random_take)
def time_reshape_transposed():
x = mx.random.uniform(shape=(256, 256, 128))
mx.eval(x)
def reshape_transposed():
return mx.reshape(mx.transpose(x, (1, 0, 2)), (-1,))
time_fn(reshape_transposed)
if __name__ == "__main__":
parser = argparse.ArgumentParser("MLX benchmarks.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
args = parser.parse_args()
if args.gpu:
mx.set_default_device(mx.gpu)
else:
mx.set_default_device(mx.cpu)
time_add()
time_matmul()
time_exp()
time_negative()
time_logsumexp()
time_take()
time_reshape_transposed()