Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-28 22:28:11 +08:00)

Commit: jagrit's commit files
benchmarks/cpp/CMakeLists.txt (Normal file, 11 lines)
@@ -0,0 +1,11 @@
function(build_benchmark SRCFILE)
  get_filename_component(src_name ${SRCFILE} NAME_WE)
  set(target "${src_name}")
  add_executable(${target} ${SRCFILE})
  target_link_libraries(${target} PRIVATE mlx)
endfunction(build_benchmark)

build_benchmark(single_ops.cpp)
build_benchmark(irregular_strides.cpp)
build_benchmark(compare_devices.cpp)
build_benchmark(autograd.cpp)

benchmarks/cpp/autograd.cpp (Normal file, 37 lines)
@@ -0,0 +1,37 @@
#include <iostream>

#include "mlx/mlx.h"
#include "time_utils.h"

using namespace mlx::core;

void time_value_and_grad() {
  auto x = ones({200, 1000});
  eval(x);
  auto fn = [](array x) {
    for (int i = 0; i < 20; ++i) {
      x = log(exp(x));
    }
    return sum(x);
  };

  auto grad_fn = grad(fn);
  auto independent_value_and_grad = [&]() {
    auto value = fn(x);
    auto dfdx = grad_fn(x);
    return std::vector<array>{value, dfdx};
  };
  TIME(independent_value_and_grad);

  auto value_and_grad_fn = value_and_grad(fn);
  auto combined_value_and_grad = [&]() {
    auto [value, dfdx] = value_and_grad_fn(x);
    return std::vector<array>{value, dfdx};
  };
  TIME(combined_value_and_grad);
}

int main() {
  std::cout << "Benchmarks for " << default_device() << std::endl;
  time_value_and_grad();
}

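The two timed closures above contrast computing the value and the gradient independently (fn(x) plus grad(fn)(x), i.e. two traversals of the graph) with the fused value_and_grad(fn) transform, which produces both in one pass. As a hedged illustration only (not part of this commit), the same comparison written against MLX's Python API:

import mlx.core as mx

def fn(x):
    for _ in range(20):
        x = mx.log(mx.exp(x))
    return mx.sum(x)

x = mx.ones((200, 1000))
mx.eval(x)

# Independent: the forward pass is traced once for the value and again for the gradient.
value = fn(x)
dfdx = mx.grad(fn)(x)

# Combined: a single transformed function returns both from one trace.
value, dfdx = mx.value_and_grad(fn)(x)
mx.eval(value, dfdx)
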
benchmarks/cpp/compare_devices.cpp (Normal file, 25 lines)
@@ -0,0 +1,25 @@
#include <iostream>
#include "mlx/mlx.h"
#include "time_utils.h"

using namespace mlx::core;

void time_add_op() {
  std::vector<int> sizes(1, 1);
  for (int i = 0; i < 9; ++i) {
    sizes.push_back(10 * sizes.back());
  }
  set_default_device(Device::cpu);
  for (auto size : sizes) {
    auto a = random::uniform({size});
    auto b = random::uniform({size});
    eval(a, b);
    std::cout << "Size " << size << std::endl;
    TIMEM("cpu", add, a, b, Device::cpu);
    TIMEM("gpu", add, a, b, Device::gpu);
  }
}

int main() {
  time_add_op();
}

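For reference, a hedged Python-API sketch of what the TIMEM lines measure: the add is dispatched to an explicit device and mx.eval forces MLX's lazy computation to complete before the clock stops. The helper name, sizes, and iteration count below are illustrative, not from the commit:

import time
import mlx.core as mx

def time_add_on(device, size=1_000_000, iters=100):
    a = mx.random.uniform(shape=(size,))
    b = mx.random.uniform(shape=(size,))
    mx.eval(a, b)  # materialize the inputs before timing
    tic = time.perf_counter()
    for _ in range(iters):
        mx.eval(mx.add(a, b, stream=device))
    return 1e3 * (time.perf_counter() - tic) / iters

print("cpu:", time_add_on(mx.cpu), "msec")
print("gpu:", time_add_on(mx.gpu), "msec")
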
benchmarks/numpy/single_ops.py (Normal file, 38 lines)
@@ -0,0 +1,38 @@
import numpy as np

from time_utils import time_fn


def time_add():
    a = np.ones((100, 100, 10), dtype=np.float32)
    b = np.ones((100, 100, 10), dtype=np.float32)
    time_fn(np.add, a, b)


def time_matmul():
    a = np.random.rand(1000, 500).astype(np.float32)
    b = np.random.rand(500, 1000).astype(np.float32)
    time_fn(np.matmul, a, b)


def time_exp():
    a = np.random.randn(1000, 100).astype(np.float32)
    time_fn(np.exp, a)


def time_take():
    a = np.random.rand(10000, 500)
    ids = np.random.randint(0, 10000, (20, 10))
    ids = [idx.reshape(-1) for idx in np.split(ids, 20)]

    def random_take():
        return [np.take(a, idx, 0) for idx in ids]

    time_fn(random_take)


if __name__ == "__main__":
    time_add()
    time_matmul()
    time_exp()
    time_take()

benchmarks/numpy/time_utils.py (Normal file, 18 lines)
@@ -0,0 +1,18 @@
import time


def time_fn(fn, *args):
    print(f"Timing {fn.__name__} ...", end=" ")

    # warmup
    for _ in range(5):
        fn(*args)

    num_iters = 100
    tic = time.perf_counter()
    for _ in range(num_iters):
        x = fn(*args)
    toc = time.perf_counter()

    msec = 1e3 * (toc - tic) / num_iters
    print(f"{msec:.5f} msec")

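A hedged usage example of time_fn (the array size is arbitrary and not taken from the benchmarks); only the op itself runs inside the warmed-up, 100-iteration timed loop:

import numpy as np
from time_utils import time_fn

a = np.random.rand(4096, 4096).astype(np.float32)
time_fn(np.sum, a)  # prints: Timing sum ... <elapsed> msec
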
benchmarks/python/blas/bench_gemm.py (Normal file, 190 lines)
@@ -0,0 +1,190 @@
import numpy as np
import argparse
import mlx.core as mx
import time
import torch
import os
import math
import subprocess

device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")

N_warmup = 8
N_iter_bench = 80
N_iter_func = 5


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def gemm_nn_mlx(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a @ b
        ys.append(y)
    mx.eval(ys)
    return ys


def gemm_nt_mlx(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a @ b.transpose((0, 2, 1))
        ys.append(y)
    mx.eval(ys)
    return ys


def gemm_tn_mlx(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a.transpose((0, 2, 1)) @ b
        ys.append(y)
    mx.eval(ys)
    return ys


def gemm_tt_mlx(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a.transpose((0, 2, 1)) @ b.transpose((0, 2, 1))
        ys.append(y)
    mx.eval(ys)
    return ys


@torch.no_grad()
def gemm_nn_torch(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a @ b
        ys.append(y)
    torch.mps.synchronize()
    return ys


@torch.no_grad()
def gemm_nt_torch(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a @ b.transpose(-1, -2)
        ys.append(y)
    torch.mps.synchronize()
    return ys


@torch.no_grad()
def gemm_tn_torch(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a.transpose(-1, -2) @ b
        ys.append(y)
    torch.mps.synchronize()
    return ys


@torch.no_grad()
def gemm_tt_torch(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a.transpose(-1, -2) @ b.transpose(-1, -2)
        ys.append(y)
    torch.mps.synchronize()
    return ys


def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
    shape_a = (B, M, K) if transpose[0] == "n" else (B, K, M)
    shape_b = (B, K, N) if transpose[1] == "n" else (B, N, K)

    a_np = np.random.normal(0.0, 1.0 / math.sqrt(M + K), shape_a).astype(np_dtype)
    b_np = np.random.normal(0.0, 1.0 / math.sqrt(N + K), shape_b).astype(np_dtype)

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np).to("mps")
    b_pt = torch.from_numpy(b_np).to("mps")

    torch.mps.synchronize()

    f_mx = {
        "nn": gemm_nn_mlx,
        "nt": gemm_nt_mlx,
        "tn": gemm_tn_mlx,
        "tt": gemm_tt_mlx,
    }[transpose]

    f_pt = {
        "nn": gemm_nn_torch,
        "nt": gemm_nt_torch,
        "tn": gemm_tn_torch,
        "tt": gemm_tt_torch,
    }[transpose]

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    t_a = (0, 1, 2) if transpose[0] == "n" else (0, 2, 1)
    t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)

    c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
    c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
        np.float32
    )

    atol = 1e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(c_mlx, c_npy.astype(np_dtype), atol=atol):
        print(
            f"Failed at {(B, M, N, K)} [transpose = {transpose}] with max(|a - b|) = {np.max(np.abs(c_npy - c_mlx))}"
        )

    return time_mlx, time_torch


def get_gflop_count(B, M, N, K):
    return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run gemm benchmarks")

    dtypes = ("float32", "float16")
    transposes = ("nn", "nt", "tn")
    shapes = (
        (16, 1024, 1024, 1024),
        (1, 1024, 1024, 2048),
        (4, 1024, 1024, 4096),
        (4, 1024, 4096, 1024),
        (1, 4096, 4096, 4096),
        (15, 1023, 1023, 1023),
        (17, 1025, 1025, 1025),
    )

    for dtype in dtypes:
        for transpose in transposes:
            for B, M, N, K in shapes:
                np_dtype = getattr(np, dtype)
                time_mlx, time_torch = bench_shape(B, M, N, K, np_dtype, transpose)

                gflop_count = get_gflop_count(B, M, N, K)
                gflops_mx = gflop_count / (time_mlx)
                gflops_pt = gflop_count / (time_torch)
                diff = gflops_mx / gflops_pt - 1.0

                print(
                    f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
                )
                if gflops_pt >= 2.0 * gflops_mx:
                    print("ATTENTION ^^^^^^^")

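One detail worth noting in the harness above: MLX evaluates lazily, so every gemm_*_mlx variant ends with mx.eval(ys), while the Torch variants end with torch.mps.synchronize() because MPS kernels are dispatched asynchronously; without those calls the loop would time only kernel enqueueing. A minimal hedged sketch of timing one batched matmul fairly on both frameworks (shapes are illustrative, not from the script):

import time
import mlx.core as mx
import torch

a_mx, b_mx = mx.ones((4, 512, 512)), mx.ones((4, 512, 512))
a_pt = torch.ones(4, 512, 512, device="mps")
b_pt = torch.ones(4, 512, 512, device="mps")

tic = time.perf_counter()
mx.eval(a_mx @ b_mx)      # force the lazy MLX graph to run before stopping the clock
t_mlx = time.perf_counter() - tic

tic = time.perf_counter()
c_pt = a_pt @ b_pt
torch.mps.synchronize()   # wait for the asynchronously dispatched MPS kernel
t_pt = time.perf_counter() - tic
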
benchmarks/python/blas/bench_gemv.py (Normal file, 219 lines)
@@ -0,0 +1,219 @@
import matplotlib.pyplot as plt
import numpy as np
import argparse
import mlx.core as mx
import time
import torch
import os
import subprocess


results_dir = "./results"

if not os.path.isdir(results_dir):
    os.mkdir(results_dir)

device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")

N_warmup = 5
N_iter_bench = 50
N_iter_func = 20

out_vec_sizes = [128, 512, 2048, 4096]
in_vec_sizes = [128, 512, 2048, 4096]

benchmark_vector_lens = []
benchmark_vector_lens += [(i + 1) * 4096 for i in range(8)][::2]
benchmark_vector_lens += [(i + 1) * 4095 for i in range(8)][::2]
benchmark_vector_lens += [(i + 1) * 4097 for i in range(8)][::2]
benchmark_vector_lens += [64, 128, 512, 1024, 2048, 11008, 32000]

benchmark_vector_lens.sort()


def bench(f, m, v):
    for i in range(N_warmup):
        f(m, v)
    torch.mps.synchronize()

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(m, v)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def gemv_mlx(m, v):
    ys = []
    for i in range(N_iter_func):
        y = m @ v
        ys.append(y)
    mx.eval(ys)
    return ys


def gemv_t_mlx(m, v):
    ys = []
    for i in range(N_iter_func):
        y = v @ m
        ys.append(y)
    mx.eval(ys)
    return ys


@torch.no_grad()
def gemv_torch(m, v):
    ys = []
    for i in range(N_iter_func):
        y = m @ v
        ys.append(y)
    torch.mps.synchronize()
    return ys


@torch.no_grad()
def gemv_t_torch(m, v):
    ys = []
    for i in range(N_iter_func):
        y = v @ m
        ys.append(y)
    torch.mps.synchronize()
    return ys


def bench_lens(in_vec_len, out_vec_len, np_dtype, transpose=False):
    shape_mat = (in_vec_len, out_vec_len) if transpose else (out_vec_len, in_vec_len)
    shape_vec = (1, in_vec_len) if transpose else (in_vec_len, 1)

    mat_npy = np.random.normal(0.0, 2.0 / in_vec_len, shape_mat).astype(np_dtype)
    vec_npy = np.random.normal(0.0, 2.0 / in_vec_len, shape_vec).astype(np_dtype)
    mat_mlx = mx.array(mat_npy)
    vec_mlx = mx.array(vec_npy)
    mat_trc = torch.from_numpy(mat_npy).to("mps")
    vec_trc = torch.from_numpy(vec_npy).to("mps")

    torch.mps.synchronize()

    time_torch = (
        bench(gemv_t_torch, mat_trc, vec_trc)
        if transpose
        else bench(gemv_torch, mat_trc, vec_trc)
    )
    time_mlx = (
        bench(gemv_t_mlx, mat_mlx, vec_mlx)
        if transpose
        else bench(gemv_mlx, mat_mlx, vec_mlx)
    )

    c_mlx = (
        np.asarray(vec_mlx @ mat_mlx) if transpose else np.asarray(mat_mlx @ vec_mlx)
    )
    c_npy = (vec_npy @ mat_npy) if transpose else (mat_npy @ vec_npy)

    if not np.allclose(c_mlx, c_npy, atol=2e-5):
        print(
            f"Failed at {shape_mat} [transpose = {transpose}] with max(|a - b|) = {np.max(np.abs(c_npy - c_mlx))}"
        )

    return time_mlx, time_torch


def get_gflop_count(in_vec_len, out_vec_len):
    return float(2.0 * N_iter_bench * N_iter_func * in_vec_len * out_vec_len) / float(
        1024**3
    )


def get_gbyte_size(in_vec_len, out_vec_len, np_dtype):
    n_elem = in_vec_len * out_vec_len + in_vec_len + out_vec_len
    item_size = 4 if np_dtype == np.float32 else 2
    return float(N_iter_bench * N_iter_func * n_elem * item_size) / float(1024**3)


def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, transpose):
    np_dtype = getattr(np, dtype)
    mlx_gb_s = []
    mlx_gflops = []
    pyt_gb_s = []
    pyt_gflops = []

    for out_vec_len in out_vector_lens:
        gflop_count = get_gflop_count(in_vec_len, out_vec_len)
        gbyte_size = get_gbyte_size(in_vec_len, out_vec_len, np_dtype)

        time_mlx, time_torch = bench_lens(in_vec_len, out_vec_len, np_dtype, transpose)

        mlx_gb_s.append(gbyte_size / time_mlx)
        pyt_gb_s.append(gbyte_size / time_torch)

        mlx_gflops.append(gflop_count / time_mlx)
        pyt_gflops.append(gflop_count / time_torch)

    if transpose:
        title = f"gemv_t ([1, {in_vec_len}] [{in_vec_len}, out_vec_len]) | {dtype}"
    else:
        title = f"gemv ([out_vec_len, {in_vec_len}] X [{in_vec_len}, 1] ) | {dtype}"

    ax.plot(out_vector_lens, mlx_gb_s, "tab:blue", label="MLX")
    ax.plot(out_vector_lens, pyt_gb_s, "tab:red", label="Torch")
    ax.set_title(title)
    ax.set(xlabel="out_vector_len", ylabel="Performance (GB/s)")
    ax.legend()


def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
    np_dtype = getattr(np, dtype)
    mlx_gb_s = []
    mlx_gflops = []
    pyt_gb_s = []
    pyt_gflops = []

    for in_vec_len in in_vector_lens:
        gflop_count = get_gflop_count(in_vec_len, out_vec_len)
        gbyte_size = get_gbyte_size(in_vec_len, out_vec_len, np_dtype)

        time_mlx, time_torch = bench_lens(in_vec_len, out_vec_len, np_dtype, transpose)

        mlx_gb_s.append(gbyte_size / time_mlx)
        pyt_gb_s.append(gbyte_size / time_torch)

        mlx_gflops.append(gflop_count / time_mlx)
        pyt_gflops.append(gflop_count / time_torch)

    if transpose:
        title = f"([1, in_vec_len] [in_vec_len, {out_vec_len}])"
    else:
        title = f"([{out_vec_len}, in_vec_len] X [in_vec_len, 1] )"

    ax.plot(in_vector_lens, mlx_gb_s, "tab:blue", label="MLX")
    ax.plot(in_vector_lens, pyt_gb_s, "tab:red", label="Torch")
    ax.set_title(title)
    ax.set(xlabel="in_vector_len", ylabel="Performance (GB/s)")
    ax.legend()


for transpose in (False, True):
    for dtype in ("float32", "float16"):
        fig, axs = plt.subplots(
            len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
        )

        for i, in_vec_len in enumerate(in_vec_sizes):
            bench_with_in_len(
                axs[i][0], in_vec_len, benchmark_vector_lens, dtype, transpose
            )

        for i, out_vec_len in enumerate(out_vec_sizes):
            bench_with_out_len(
                axs[i][1], out_vec_len, benchmark_vector_lens, dtype, transpose
            )

        op_name = "gemv_t" if transpose else "gemv"
        fig.suptitle(f"{device_name}: {dtype} {op_name}")
        fig.savefig(
            os.path.join(
                results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
            )
        )
        plt.close(fig)

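The plots report GB/s rather than GFLOP/s because a matrix-vector multiply performs only about two flops per matrix element loaded, so it is limited by memory bandwidth rather than compute. A hedged back-of-the-envelope check, mirroring get_gflop_count and get_gbyte_size above:

in_len, out_len, itemsize = 4096, 4096, 4  # float32
flops = 2.0 * in_len * out_len                                  # one multiply-add per matrix element
bytes_moved = (in_len * out_len + in_len + out_len) * itemsize  # matrix + input vector + output vector
print(flops / bytes_moved)                                      # ~0.5 flops per byte: bandwidth-bound
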
benchmarks/python/llama_mlx_bench.py (Normal file, 116 lines)
@@ -0,0 +1,116 @@
import math
import time

import mlx.core as mx
import mlx.nn as nn
import mlx.utils


class LlamaAttention(nn.Module):
    def __init__(self, dims: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.rope = nn.RoPE(dims // num_heads, True)
        self.query_proj = nn.Linear(dims, dims, False)
        self.key_proj = nn.Linear(dims, dims, False)
        self.value_proj = nn.Linear(dims, dims, False)
        self.out_proj = nn.Linear(dims, dims, False)

    def __call__(self, queries, keys, values, mask=None, cache=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, D = queries.shape
        queries = mx.transpose(mx.reshape(queries, (B, L, num_heads, -1)), (0, 2, 1, 3))
        keys = mx.transpose(mx.reshape(keys, (B, L, num_heads, -1)), (0, 2, 1, 3))
        values = mx.transpose(mx.reshape(values, (B, L, num_heads, -1)), (0, 2, 1, 3))

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Dimensions are [batch x num heads x sequence x hidden dim]
        scale = mx.array(math.sqrt(1 / queries.shape[-1]), dtype=queries.dtype)
        scores = (queries * scale) @ mx.transpose(keys, (0, 1, 3, 2))
        if mask is not None:
            scores = scores + mask
        scores = mx.softmax(scores, axis=-1)
        values_hat = mx.reshape(mx.transpose(scores @ values, (0, 2, 1, 3)), (B, L, -1))

        return self.out_proj(values_hat), (keys, values)


class LlamaEncoderLayer(nn.Module):
    def __init__(self, dims: int, mlp_dims: int, num_heads: int):
        super().__init__()

        self.attention = LlamaAttention(dims, num_heads)

        self.norm1 = nn.RMSNorm(dims)
        self.norm2 = nn.RMSNorm(dims)

        self.linear1 = nn.Linear(dims, mlp_dims, False)
        self.linear2 = nn.Linear(dims, mlp_dims, False)
        self.linear3 = nn.Linear(mlp_dims, dims, False)

    def __call__(self, x, mask=None, cache=None):
        y = self.norm1(x)
        y, cache = self.attention(y, y, y, mask, cache)
        x = x + y

        y = self.norm2(x)
        a = self.linear1(y)
        b = self.linear2(y)
        y = a * mx.sigmoid(a) * b
        y = self.linear3(y)
        x = x + y

        return x, cache


def measure(model, x, cache):
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
        mx.eval(y, c)

    start = time.time()
    rs = []
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
        rs.append((y, c))
    mx.eval(rs)
    end = time.time()

    return (end - start) * 1000 / 5


if __name__ == "__main__":
    H = 32
    D = 4096
    F = 43 * 256
    C = 1000
    mx.set_default_device(mx.gpu)
    dtype = mx.float16

    layer = LlamaEncoderLayer(D, F, H)
    layer.update(mlx.utils.tree_map(lambda x: x.astype(dtype), layer.parameters()))
    k1, k2, k3 = mx.random.split(mx.random.key(0), 3)
    x = mx.random.normal([1, 1, D], dtype=dtype)
    cache = [
        mx.random.normal([1, H, C, D // H], dtype=dtype),
        mx.random.normal([1, H, C, D // H], dtype=dtype),
    ]
    mx.eval(x, cache)

    T = measure(layer, x, cache)

    print("Time per layer per token:", T, "ms")
    print("Lower bound total time per token:", T * 32, "ms")

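A hedged reading of the printed summary: the script times one encoder layer decoding a single token against a 1000-entry key/value cache, and the final line multiplies by 32 on the assumption that a 7B-class Llama stacks 32 such layers. For example:

ms_per_layer = 5.0   # placeholder: the "Time per layer per token" printed above
num_layers = 32      # assumption: a 7B-parameter Llama-style model has 32 such layers
ms_per_token = ms_per_layer * num_layers
print("best case:", 1000.0 / ms_per_token, "tokens/sec (ignores embeddings, final norm, and lm head)")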