jagrit's commit files

Jagrit Digani
2023-11-29 10:52:08 -08:00
parent d1f86272a2
commit e6306cfee9
74 changed files with 15964 additions and 2 deletions

@@ -0,0 +1,190 @@
import argparse
import math
import os
import subprocess
import time

import mlx.core as mx
import numpy as np
import torch

device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")

N_warmup = 8
N_iter_bench = 80
N_iter_func = 5

def bench(f, a, b):
    # Warm up, then time N_iter_bench calls; each call runs N_iter_func matmuls.
    for i in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()
    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9

def gemm_nn_mlx(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a @ b
        ys.append(y)
    mx.eval(ys)
    return ys

def gemm_nt_mlx(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a @ b.transpose((0, 2, 1))
        ys.append(y)
    mx.eval(ys)
    return ys

def gemm_tn_mlx(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a.transpose((0, 2, 1)) @ b
        ys.append(y)
    mx.eval(ys)
    return ys

def gemm_tt_mlx(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a.transpose((0, 2, 1)) @ b.transpose((0, 2, 1))
        ys.append(y)
    mx.eval(ys)
    return ys

@torch.no_grad()
def gemm_nn_torch(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a @ b
        ys.append(y)
    torch.mps.synchronize()
    return ys

@torch.no_grad()
def gemm_nt_torch(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a @ b.transpose(-1, -2)
        ys.append(y)
    torch.mps.synchronize()
    return ys

@torch.no_grad()
def gemm_tn_torch(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a.transpose(-1, -2) @ b
        ys.append(y)
    torch.mps.synchronize()
    return ys

@torch.no_grad()
def gemm_tt_torch(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a.transpose(-1, -2) @ b.transpose(-1, -2)
        ys.append(y)
    torch.mps.synchronize()
    return ys
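
# Each timed call above performs N_iter_func matmuls and synchronizes once
# (mx.eval for MLX, torch.mps.synchronize for PyTorch), amortizing dispatch
# overhead; per-matmul time is bench(...) / (N_iter_bench * N_iter_func).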

def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
    # "n" means an operand is stored as used; "t" means it is stored
    # transposed and brought back to shape inside the matmul.
    shape_a = (B, M, K) if transpose[0] == "n" else (B, K, M)
    shape_b = (B, K, N) if transpose[1] == "n" else (B, N, K)

    a_np = np.random.normal(0.0, 1.0 / math.sqrt(M + K), shape_a).astype(np_dtype)
    b_np = np.random.normal(0.0, 1.0 / math.sqrt(N + K), shape_b).astype(np_dtype)

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np).to("mps")
    b_pt = torch.from_numpy(b_np).to("mps")
    torch.mps.synchronize()

    f_mx = {
        "nn": gemm_nn_mlx,
        "nt": gemm_nt_mlx,
        "tn": gemm_tn_mlx,
        "tt": gemm_tt_mlx,
    }[transpose]

    f_pt = {
        "nn": gemm_nn_torch,
        "nt": gemm_nt_torch,
        "tn": gemm_tn_torch,
        "tt": gemm_tt_torch,
    }[transpose]

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    # Check the MLX result against a float32 numpy reference.
    t_a = (0, 1, 2) if transpose[0] == "n" else (0, 2, 1)
    t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)
    c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
    c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
        np.float32
    )

    atol = 1e-5 if np_dtype == np.float32 else 1e-4
    if not np.allclose(c_mlx, c_npy.astype(np_dtype), atol=atol):
        print(
            f"Failed at {(B, M, N, K)} [transpose = {transpose}] with max(|a - b|) = {np.max(np.abs(c_npy - c_mlx))}"
        )

    return time_mlx, time_torch

def get_gflop_count(B, M, N, K):
    # 2*M*N*K FLOPs per GEMM (one multiply and one add per MAC),
    # summed over all timed iterations.
    return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run gemm benchmarks")

    dtypes = ("float32", "float16")
    # The "tt" kernels are defined above but not swept by default.
    transposes = ("nn", "nt", "tn")
    shapes = (
        (16, 1024, 1024, 1024),
        (1, 1024, 1024, 2048),
        (4, 1024, 1024, 4096),
        (4, 1024, 4096, 1024),
        (1, 4096, 4096, 4096),
        (15, 1023, 1023, 1023),
        (17, 1025, 1025, 1025),
    )

    for dtype in dtypes:
        for transpose in transposes:
            for B, M, N, K in shapes:
                np_dtype = getattr(np, dtype)
                time_mlx, time_torch = bench_shape(B, M, N, K, np_dtype, transpose)

                gflop_count = get_gflop_count(B, M, N, K)
                gflops_mx = gflop_count / time_mlx
                gflops_pt = gflop_count / time_torch
                diff = gflops_mx / gflops_pt - 1.0

                print(
                    f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
                )
                if gflops_pt >= 2.0 * gflops_mx:
                    print("ATTENTION ^^^^^^^")

@@ -0,0 +1,219 @@
import argparse
import os
import subprocess
import time

import matplotlib.pyplot as plt
import mlx.core as mx
import numpy as np
import torch

results_dir = "./results"
if not os.path.isdir(results_dir):
    os.mkdir(results_dir)

device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")

N_warmup = 5
N_iter_bench = 50
N_iter_func = 20

out_vec_sizes = [128, 512, 2048, 4096]
in_vec_sizes = [128, 512, 2048, 4096]

# Sweep every other multiple of 4096, 4095, and 4097 (sizes on and just off
# alignment boundaries), plus a few fixed sizes.
benchmark_vector_lens = []
benchmark_vector_lens += [(i + 1) * 4096 for i in range(8)][::2]
benchmark_vector_lens += [(i + 1) * 4095 for i in range(8)][::2]
benchmark_vector_lens += [(i + 1) * 4097 for i in range(8)][::2]
benchmark_vector_lens += [64, 128, 512, 1024, 2048, 11008, 32000]
benchmark_vector_lens.sort()

def bench(f, m, v):
    for i in range(N_warmup):
        f(m, v)
    torch.mps.synchronize()
    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(m, v)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9

def gemv_mlx(m, v):
    ys = []
    for i in range(N_iter_func):
        y = m @ v
        ys.append(y)
    mx.eval(ys)
    return ys

def gemv_t_mlx(m, v):
    ys = []
    for i in range(N_iter_func):
        y = v @ m
        ys.append(y)
    mx.eval(ys)
    return ys

@torch.no_grad()
def gemv_torch(m, v):
    ys = []
    for i in range(N_iter_func):
        y = m @ v
        ys.append(y)
    torch.mps.synchronize()
    return ys

@torch.no_grad()
def gemv_t_torch(m, v):
    ys = []
    for i in range(N_iter_func):
        y = v @ m
        ys.append(y)
    torch.mps.synchronize()
    return ys

def bench_lens(in_vec_len, out_vec_len, np_dtype, transpose=False):
    # gemv: (out, in) @ (in, 1); gemv_t: (1, in) @ (in, out)
    shape_mat = (in_vec_len, out_vec_len) if transpose else (out_vec_len, in_vec_len)
    shape_vec = (1, in_vec_len) if transpose else (in_vec_len, 1)
    mat_npy = np.random.normal(0.0, 2.0 / in_vec_len, shape_mat).astype(np_dtype)
    vec_npy = np.random.normal(0.0, 2.0 / in_vec_len, shape_vec).astype(np_dtype)

    mat_mlx = mx.array(mat_npy)
    vec_mlx = mx.array(vec_npy)
    mat_trc = torch.from_numpy(mat_npy).to("mps")
    vec_trc = torch.from_numpy(vec_npy).to("mps")
    torch.mps.synchronize()

    time_torch = (
        bench(gemv_t_torch, mat_trc, vec_trc)
        if transpose
        else bench(gemv_torch, mat_trc, vec_trc)
    )
    time_mlx = (
        bench(gemv_t_mlx, mat_mlx, vec_mlx)
        if transpose
        else bench(gemv_mlx, mat_mlx, vec_mlx)
    )

    # Correctness check against numpy.
    c_mlx = (
        np.asarray(vec_mlx @ mat_mlx) if transpose else np.asarray(mat_mlx @ vec_mlx)
    )
    c_npy = (vec_npy @ mat_npy) if transpose else (mat_npy @ vec_npy)
    if not np.allclose(c_mlx, c_npy, atol=2e-5):
        print(
            f"Failed at {shape_mat} [transpose = {transpose}] with max(|a - b|) = {np.max(np.abs(c_npy - c_mlx))}"
        )

    return time_mlx, time_torch

def get_gflop_count(in_vec_len, out_vec_len):
    # 2 FLOPs (multiply + add) per matrix element, over all timed iterations.
    return float(2.0 * N_iter_bench * N_iter_func * in_vec_len * out_vec_len) / float(
        1024**3
    )

def get_gbyte_size(in_vec_len, out_vec_len, np_dtype):
    # Bytes moved per gemv: the matrix plus the input and output vectors.
    n_elem = in_vec_len * out_vec_len + in_vec_len + out_vec_len
    item_size = 4 if np_dtype == np.float32 else 2
    return float(N_iter_bench * N_iter_func * n_elem * item_size) / float(1024**3)
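
# gemv loads each matrix element once for a single multiply-add, so it is
# memory-bandwidth bound; GB/s is the more telling metric here, and the
# GFLOPS lists below are computed but not plotted.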

def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, transpose):
    np_dtype = getattr(np, dtype)
    mlx_gb_s = []
    mlx_gflops = []
    pyt_gb_s = []
    pyt_gflops = []

    for out_vec_len in out_vector_lens:
        gflop_count = get_gflop_count(in_vec_len, out_vec_len)
        gbyte_size = get_gbyte_size(in_vec_len, out_vec_len, np_dtype)
        time_mlx, time_torch = bench_lens(in_vec_len, out_vec_len, np_dtype, transpose)
        mlx_gb_s.append(gbyte_size / time_mlx)
        pyt_gb_s.append(gbyte_size / time_torch)
        mlx_gflops.append(gflop_count / time_mlx)
        pyt_gflops.append(gflop_count / time_torch)

    if transpose:
        title = f"gemv_t ([1, {in_vec_len}] X [{in_vec_len}, out_vec_len]) | {dtype}"
    else:
        title = f"gemv ([out_vec_len, {in_vec_len}] X [{in_vec_len}, 1]) | {dtype}"

    ax.plot(out_vector_lens, mlx_gb_s, "tab:blue", label="MLX")
    ax.plot(out_vector_lens, pyt_gb_s, "tab:red", label="Torch")
    ax.set_title(title)
    ax.set(xlabel="out_vector_len", ylabel="Performance (GB/s)")
    ax.legend()

def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
    np_dtype = getattr(np, dtype)
    mlx_gb_s = []
    mlx_gflops = []
    pyt_gb_s = []
    pyt_gflops = []

    for in_vec_len in in_vector_lens:
        gflop_count = get_gflop_count(in_vec_len, out_vec_len)
        gbyte_size = get_gbyte_size(in_vec_len, out_vec_len, np_dtype)
        time_mlx, time_torch = bench_lens(in_vec_len, out_vec_len, np_dtype, transpose)
        mlx_gb_s.append(gbyte_size / time_mlx)
        pyt_gb_s.append(gbyte_size / time_torch)
        mlx_gflops.append(gflop_count / time_mlx)
        pyt_gflops.append(gflop_count / time_torch)

    if transpose:
        title = f"([1, in_vec_len] X [in_vec_len, {out_vec_len}])"
    else:
        title = f"([{out_vec_len}, in_vec_len] X [in_vec_len, 1])"

    ax.plot(in_vector_lens, mlx_gb_s, "tab:blue", label="MLX")
    ax.plot(in_vector_lens, pyt_gb_s, "tab:red", label="Torch")
    ax.set_title(title)
    ax.set(xlabel="in_vector_len", ylabel="Performance (GB/s)")
    ax.legend()

for transpose in (False, True):
    for dtype in ("float32", "float16"):
        fig, axs = plt.subplots(
            len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
        )
        for i, in_vec_len in enumerate(in_vec_sizes):
            bench_with_in_len(
                axs[i][0], in_vec_len, benchmark_vector_lens, dtype, transpose
            )
        for i, out_vec_len in enumerate(out_vec_sizes):
            bench_with_out_len(
                axs[i][1], out_vec_len, benchmark_vector_lens, dtype, transpose
            )
        op_name = "gemv_t" if transpose else "gemv"
        fig.suptitle(f"{device_name}: {dtype} {op_name}")
        fig.savefig(
            os.path.join(
                results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
            )
        )
        plt.close(fig)
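
Each dtype/op combination produces one letter-sized page of eight panels saved under ./results. If the script needs to run headlessly (e.g., over SSH), a minimal tweak, assuming no other backend configuration is in place, is to force a non-interactive backend before pyplot is imported:

import matplotlib
matplotlib.use("Agg")  # non-interactive backend; must run before "import matplotlib.pyplot"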