# mirror of https://github.com/ml-explore/mlx.git
# synced 2025-12-14 09:07:12 +08:00
# Co-authored-by: Awni Hannun <awni@apple.com>
# Co-authored-by: Angelos Katharopoulos <katharas@gmail.com>
# Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
import math
|
|
import os
|
|
import subprocess
|
|
import time
|
|
from copy import copy
|
|
from functools import partial
|
|
|
|
import matplotlib.pyplot as plt
|
|
import mlx.core as mx
|
|
import numpy as np
|
|
import torch
|
|
from matplotlib.ticker import FuncFormatter
|
|
|
|
# Output directory for the generated benchmark plots.
RESULTS_DIR = "./results"

# makedirs with exist_ok=True is race-free (no isdir/mkdir TOCTOU window)
# and also creates missing parent directories, unlike the bare mkdir.
os.makedirs(RESULTS_DIR, exist_ok=True)

# Human-readable chip name, used in plot titles and output file names.
# NOTE(review): `sysctl machdep.cpu.brand_string` is macOS-specific; this
# script assumes it runs on Apple hardware.
DEVICE_NAME = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
DEVICE_NAME = DEVICE_NAME.decode("utf-8").strip("\n")

# Torch device preference: Apple GPU (MPS) first, then CUDA, then CPU.
TORCH_DEVICE = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else ("cuda" if torch.cuda.is_available() else "cpu")
)

# Benchmark shape: untimed warmup calls, timed outer iterations, and inner
# repetitions performed inside each timed function call.
N_WARMUP = 5
N_ITER_BENCH = 50
N_ITER_FUNC = 20

# Sweep space: vector lengths 2^12 .. 2^21, mask densities, and dtypes.
VECTOR_LENGTHS = [4096 * (2**i) for i in range(10)]
MASK_DENSITIES = [0.01, 0.1, 0.25, 0.5]
D_TYPES = ("float32", "float16")
def _power_of_two_formatter(value, _position):
|
|
if value <= 0:
|
|
return ""
|
|
exponent = int(round(math.log2(value)))
|
|
if abs(value - (1 << exponent)) / value > 1e-6:
|
|
return f"{value:g}"
|
|
return f"$2^{{{exponent}}}$"
|
|
|
|
|
|
def torch_sync():
    """Block until all queued kernels on the active Torch device have finished.

    No-op on CPU, where execution is already synchronous.
    """
    sync_for_type = {
        "cuda": torch.cuda.synchronize,
        "mps": torch.mps.synchronize,
    }
    sync = sync_for_type.get(TORCH_DEVICE.type)
    if sync is not None:
        sync()
def masked_scatter_mlx(self_arr, mask_arr, src_arr):
    """Scatter src_arr into copies of self_arr at mask positions (MLX side).

    Repeats the operation N_ITER_FUNC times, forcing evaluation once at the
    end so the lazily-built graph is actually executed, and returns the list
    of result arrays.
    """
    results = []
    for _ in range(N_ITER_FUNC):
        scattered = copy(self_arr)
        scattered[mask_arr] = src_arr  # boolean-mask assignment == masked scatter
        results.append(scattered)
    mx.eval(results)  # single eval amortizes graph execution over all copies
    return results
@torch.no_grad()
def masked_scatter_torch(self_tensor, mask_tensor, src_tensor):
    """Scatter src_tensor into clones of self_tensor at mask positions (Torch).

    Repeats the operation N_ITER_FUNC times, synchronizes the device once at
    the end, and returns the list of result tensors.
    """

    def _one_scatter():
        # Clone first: masked_scatter_ mutates its receiver in place.
        cloned = self_tensor.clone()
        cloned.masked_scatter_(mask_tensor, src_tensor)
        return cloned

    results = [_one_scatter() for _ in range(N_ITER_FUNC)]
    torch_sync()
    return results
def measure(fn):
    """Time *fn*: N_WARMUP untimed calls, then N_ITER_BENCH timed calls.

    Returns the total elapsed wall-clock time of the timed calls in seconds.
    """
    for _ in range(N_WARMUP):
        fn()

    tic = time.perf_counter_ns()
    for _ in range(N_ITER_BENCH):
        fn()
    toc = time.perf_counter_ns()

    # perf_counter_ns avoids float rounding while timing; convert ns -> s once.
    return (toc - tic) * 1e-9
def bytes_touched(length, true_count, item_size):
    """Estimate the total bytes moved over the whole benchmark run.

    Per inner call this counts the boolean-mask read (1 byte/element), the
    destination read + write, and the gathered source values; the sum is then
    scaled by both iteration counts.
    """
    per_call = length  # boolean mask: one byte per element
    per_call += 2 * length * item_size  # destination array: read + write
    per_call += true_count * item_size  # source values actually scattered
    return per_call * N_ITER_FUNC * N_ITER_BENCH
def build_case(length, density, np_dtype, torch_dtype):
    """Create matching MLX and Torch inputs for one benchmark configuration.

    Builds a random base vector, a boolean mask with exactly *true_count*
    True entries at random positions, and a source vector of that size; then
    checks once that MLX and Torch agree on the masked-scatter result.

    Returns:
        (self_mlx, mask_mlx, src_mlx, self_torch, mask_torch, src_torch,
        true_count).

    Raises:
        AssertionError: if the MLX and Torch results diverge beyond tolerance.
    """
    true_count = max(1, int(round(length * density)))

    rng = np.random.default_rng()
    base_np = rng.normal(0.0, 1.0, length).astype(np_dtype)
    mask_np = np.zeros(length, dtype=bool)
    mask_np[:true_count] = True
    rng.shuffle(mask_np)  # randomize positions while keeping the exact count
    src_np = rng.normal(0.0, 1.0, true_count).astype(np_dtype)

    self_mlx = mx.array(base_np)
    mask_mlx = mx.array(mask_np)
    src_mlx = mx.array(src_np)

    self_torch = torch.from_numpy(base_np).to(device=TORCH_DEVICE, dtype=torch_dtype)
    mask_torch = torch.from_numpy(mask_np).to(device=TORCH_DEVICE)
    src_torch = torch.from_numpy(src_np).to(device=TORCH_DEVICE, dtype=torch_dtype)

    # One-off correctness check for this configuration before timing anything.
    mx_out = mx.array(base_np)
    mx_out[mask_mlx] = src_mlx
    mx.eval(mx_out)
    torch_out = self_torch.clone()
    torch_out.masked_scatter_(mask_torch, src_torch)

    # fp16 needs a much looser tolerance than fp32.
    tolerance = 5e-3 if np_dtype == np.float16 else 1e-5
    if not np.allclose(np.array(mx_out), torch_out.cpu().numpy(), atol=tolerance):
        raise AssertionError("masked_scatter results diverged between MLX and Torch")

    return (self_mlx, mask_mlx, src_mlx, self_torch, mask_torch, src_torch, true_count)
def bench_case(length, density, dtype):
    """Benchmark one (length, density, dtype) configuration.

    Returns:
        (time_mlx, time_torch, mlx_gbps, torch_gbps) — total timed seconds
        per framework and the derived effective bandwidth in GiB/s.
    """
    np_dtype = getattr(np, dtype)
    torch_dtype = getattr(torch, dtype)

    case = build_case(length, density, np_dtype, torch_dtype)
    self_mlx, mask_mlx, src_mlx, self_torch, mask_torch, src_torch, true_count = case

    time_mlx = measure(partial(masked_scatter_mlx, self_mlx, mask_mlx, src_mlx))
    time_torch = measure(
        partial(masked_scatter_torch, self_torch, mask_torch, src_torch)
    )

    total_bytes = bytes_touched(length, true_count, np_dtype().itemsize)
    gib = float(1024**3)
    mlx_gbps = total_bytes / gib / time_mlx
    torch_gbps = total_bytes / gib / time_torch

    return time_mlx, time_torch, mlx_gbps, torch_gbps
def plot_density(ax_perf, ax_speedup, density, dtype):
    """Plot one mask density: GB/s per framework (left) and speedup (right).

    Runs bench_case over every length in VECTOR_LENGTHS and draws onto the
    two provided axes; both x-axes use a log-2 scale with power-of-two ticks.
    """
    mlx_gbps, torch_gbps = [], []
    mlx_times, torch_times = [], []

    for length in VECTOR_LENGTHS:
        t_mlx, t_torch, g_mlx, g_torch = bench_case(length, density, dtype)
        mlx_times.append(t_mlx)
        torch_times.append(t_torch)
        mlx_gbps.append(g_mlx)
        torch_gbps.append(g_torch)

    # Shared tick formatter for both axes.
    formatter = FuncFormatter(_power_of_two_formatter)

    ax_perf.plot(VECTOR_LENGTHS, mlx_gbps, "tab:blue", label="MLX")
    ax_perf.plot(VECTOR_LENGTHS, torch_gbps, "tab:red", label="Torch")
    ax_perf.set_xscale("log", base=2)
    ax_perf.set_xticks(VECTOR_LENGTHS)
    ax_perf.xaxis.set_major_formatter(formatter)
    ax_perf.set_title(f"density={density:.2f}")
    ax_perf.set_ylabel("GB/s")
    ax_perf.grid(True, which="both", linestyle=":", alpha=0.4)
    ax_perf.legend()

    # Ratio > 1 means MLX completed the same work faster than Torch.
    speedup = np.array(torch_times) / np.array(mlx_times)
    ax_speedup.plot(VECTOR_LENGTHS, speedup, "tab:green")
    ax_speedup.axhline(1.0, color="tab:gray", linestyle="--")
    ax_speedup.set_xscale("log", base=2)
    ax_speedup.set_xticks(VECTOR_LENGTHS)
    ax_speedup.xaxis.set_major_formatter(formatter)
    ax_speedup.set_ylabel("Speedup (Torch_t / MLX_t)")
    ax_speedup.grid(True, which="both", linestyle=":", alpha=0.4)
def main():
    """Run the full benchmark sweep and save one PDF figure per dtype.

    Each figure has one row per mask density with a bandwidth plot on the
    left and a speedup plot on the right; files land in RESULTS_DIR.
    """
    for dtype in D_TYPES:
        fig, axs = plt.subplots(
            len(MASK_DENSITIES), 2, figsize=(10, 12), layout="constrained", sharex=True
        )

        for row, density in enumerate(MASK_DENSITIES):
            perf_ax = axs[row][0]
            speedup_ax = axs[row][1]
            plot_density(perf_ax, speedup_ax, density, dtype)
            perf_ax.set_xlabel("vector length")
            speedup_ax.set_xlabel("vector length")

        fig.suptitle(
            f"{DEVICE_NAME.replace('Apple ', '')} ({TORCH_DEVICE.type}) | dtype={dtype}"
        )
        output_path = os.path.join(
            RESULTS_DIR,
            f"{DEVICE_NAME.replace(' ', '_')}_masked_scatter_{dtype}.pdf",
        )
        fig.savefig(output_path)
        plt.close(fig)  # free figure memory between dtypes
# Script entry point: run the benchmark sweep only when executed directly.
if __name__ == "__main__":
    main()