mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 09:29:26 +08:00
Add Masked Scatter (#2663)
Co-authored-by: Awni Hannun <awni@apple.com> Co-authored-by: Angelos Katharopoulos <katharas@gmail.com> Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
212
benchmarks/python/masked_scatter.py
Normal file
212
benchmarks/python/masked_scatter.py
Normal file
@@ -0,0 +1,212 @@
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from copy import copy
|
||||
from functools import partial
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.ticker import FuncFormatter
|
||||
|
||||
RESULTS_DIR = "./results"
|
||||
|
||||
|
||||
if not os.path.isdir(RESULTS_DIR):
|
||||
os.mkdir(RESULTS_DIR)
|
||||
|
||||
DEVICE_NAME = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
|
||||
DEVICE_NAME = DEVICE_NAME.decode("utf-8").strip("\n")
|
||||
|
||||
TORCH_DEVICE = torch.device(
|
||||
"mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
)
|
||||
|
||||
|
||||
N_WARMUP = 5
|
||||
N_ITER_BENCH = 50
|
||||
N_ITER_FUNC = 20
|
||||
|
||||
VECTOR_LENGTHS = [4096 * (2**i) for i in range(10)]
|
||||
MASK_DENSITIES = [0.01, 0.1, 0.25, 0.5]
|
||||
D_TYPES = ("float32", "float16")
|
||||
|
||||
|
||||
def _power_of_two_formatter(value, _position):
|
||||
if value <= 0:
|
||||
return ""
|
||||
exponent = int(round(math.log2(value)))
|
||||
if abs(value - (1 << exponent)) / value > 1e-6:
|
||||
return f"{value:g}"
|
||||
return f"$2^{{{exponent}}}$"
|
||||
|
||||
|
||||
def torch_sync():
|
||||
if TORCH_DEVICE.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
elif TORCH_DEVICE.type == "mps":
|
||||
torch.mps.synchronize()
|
||||
|
||||
|
||||
def masked_scatter_mlx(self_arr, mask_arr, src_arr):
|
||||
outs = []
|
||||
for _ in range(N_ITER_FUNC):
|
||||
out = copy(self_arr)
|
||||
out[mask_arr] = src_arr
|
||||
outs.append(out)
|
||||
mx.eval(outs)
|
||||
return outs
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def masked_scatter_torch(self_tensor, mask_tensor, src_tensor):
|
||||
outs = []
|
||||
for _ in range(N_ITER_FUNC):
|
||||
out = self_tensor.clone()
|
||||
out.masked_scatter_(mask_tensor, src_tensor)
|
||||
outs.append(out)
|
||||
torch_sync()
|
||||
return outs
|
||||
|
||||
|
||||
def measure(fn):
|
||||
for _ in range(N_WARMUP):
|
||||
fn()
|
||||
start = time.perf_counter_ns()
|
||||
for _ in range(N_ITER_BENCH):
|
||||
fn()
|
||||
end = time.perf_counter_ns()
|
||||
return (end - start) * 1e-9
|
||||
|
||||
|
||||
def bytes_touched(length, true_count, item_size):
|
||||
mask_bytes = length
|
||||
self_bytes = length * item_size * 2 # read + write
|
||||
src_bytes = true_count * item_size
|
||||
return (mask_bytes + self_bytes + src_bytes) * N_ITER_FUNC * N_ITER_BENCH
|
||||
|
||||
|
||||
def build_case(length, density, np_dtype, torch_dtype):
|
||||
true_count = max(1, int(round(length * density)))
|
||||
|
||||
rng = np.random.default_rng()
|
||||
self_np = rng.normal(0.0, 1.0, length).astype(np_dtype)
|
||||
mask_np = np.zeros(length, dtype=bool)
|
||||
mask_np[:true_count] = True
|
||||
rng.shuffle(mask_np)
|
||||
src_np = rng.normal(0.0, 1.0, true_count).astype(np_dtype)
|
||||
|
||||
self_mlx = mx.array(self_np)
|
||||
mask_mlx = mx.array(mask_np)
|
||||
src_mlx = mx.array(src_np)
|
||||
|
||||
self_torch = torch.from_numpy(self_np).to(device=TORCH_DEVICE, dtype=torch_dtype)
|
||||
mask_torch = torch.from_numpy(mask_np).to(device=TORCH_DEVICE)
|
||||
src_torch = torch.from_numpy(src_np).to(device=TORCH_DEVICE, dtype=torch_dtype)
|
||||
|
||||
# Correctness check once per configuration
|
||||
mx_out = mx.array(self_np)
|
||||
mx_out[mask_mlx] = src_mlx
|
||||
mx.eval(mx_out)
|
||||
torch_out = self_torch.clone()
|
||||
torch_out.masked_scatter_(mask_torch, src_torch)
|
||||
|
||||
atol = 5e-3 if np_dtype == np.float16 else 1e-5
|
||||
if not np.allclose(np.array(mx_out), torch_out.cpu().numpy(), atol=atol):
|
||||
raise AssertionError("masked_scatter results diverged between MLX and Torch")
|
||||
|
||||
return (self_mlx, mask_mlx, src_mlx, self_torch, mask_torch, src_torch, true_count)
|
||||
|
||||
|
||||
def bench_case(length, density, dtype):
|
||||
np_dtype = getattr(np, dtype)
|
||||
torch_dtype = getattr(torch, dtype)
|
||||
(
|
||||
self_mlx,
|
||||
mask_mlx,
|
||||
src_mlx,
|
||||
self_torch,
|
||||
mask_torch,
|
||||
src_torch,
|
||||
true_count,
|
||||
) = build_case(length, density, np_dtype, torch_dtype)
|
||||
|
||||
time_mlx = measure(partial(masked_scatter_mlx, self_mlx, mask_mlx, src_mlx))
|
||||
time_torch = measure(
|
||||
partial(masked_scatter_torch, self_torch, mask_torch, src_torch)
|
||||
)
|
||||
|
||||
total_bytes = bytes_touched(length, true_count, np_dtype().itemsize)
|
||||
bytes_per_gb = float(1024**3)
|
||||
mlx_gbps = (total_bytes / bytes_per_gb) / time_mlx
|
||||
torch_gbps = (total_bytes / bytes_per_gb) / time_torch
|
||||
|
||||
return time_mlx, time_torch, mlx_gbps, torch_gbps
|
||||
|
||||
|
||||
def plot_density(ax_perf, ax_speedup, density, dtype):
|
||||
mlx_gbps = []
|
||||
torch_gbps = []
|
||||
mlx_times = []
|
||||
torch_times = []
|
||||
|
||||
for length in VECTOR_LENGTHS:
|
||||
t_mlx, t_torch, gbps_mlx, gbps_torch = bench_case(length, density, dtype)
|
||||
mlx_gbps.append(gbps_mlx)
|
||||
torch_gbps.append(gbps_torch)
|
||||
mlx_times.append(t_mlx)
|
||||
torch_times.append(t_torch)
|
||||
|
||||
ax_perf.plot(VECTOR_LENGTHS, mlx_gbps, "tab:blue", label="MLX")
|
||||
ax_perf.plot(VECTOR_LENGTHS, torch_gbps, "tab:red", label="Torch")
|
||||
ax_perf.set_xscale("log", base=2)
|
||||
ax_perf.set_xticks(VECTOR_LENGTHS)
|
||||
formatter = FuncFormatter(_power_of_two_formatter)
|
||||
ax_perf.xaxis.set_major_formatter(formatter)
|
||||
ax_perf.set_title(f"density={density:.2f}")
|
||||
ax_perf.set_ylabel("GB/s")
|
||||
ax_perf.grid(True, which="both", linestyle=":", alpha=0.4)
|
||||
ax_perf.legend()
|
||||
|
||||
speedup = np.array(torch_times) / np.array(mlx_times)
|
||||
ax_speedup.plot(VECTOR_LENGTHS, speedup, "tab:green")
|
||||
ax_speedup.axhline(1.0, color="tab:gray", linestyle="--")
|
||||
ax_speedup.set_xscale("log", base=2)
|
||||
ax_speedup.set_xticks(VECTOR_LENGTHS)
|
||||
ax_speedup.xaxis.set_major_formatter(formatter)
|
||||
ax_speedup.set_ylabel("Speedup (Torch_t / MLX_t)")
|
||||
ax_speedup.grid(True, which="both", linestyle=":", alpha=0.4)
|
||||
|
||||
|
||||
def main():
|
||||
for dtype in D_TYPES:
|
||||
fig, axs = plt.subplots(
|
||||
len(MASK_DENSITIES),
|
||||
2,
|
||||
figsize=(10, 12),
|
||||
layout="constrained",
|
||||
sharex=True,
|
||||
)
|
||||
|
||||
for i, density in enumerate(MASK_DENSITIES):
|
||||
plot_density(axs[i][0], axs[i][1], density, dtype)
|
||||
axs[i][0].set_xlabel("vector length")
|
||||
axs[i][1].set_xlabel("vector length")
|
||||
|
||||
fig.suptitle(
|
||||
f"{DEVICE_NAME.replace('Apple ', '')} ({TORCH_DEVICE.type}) | dtype={dtype}"
|
||||
)
|
||||
output_path = os.path.join(
|
||||
RESULTS_DIR,
|
||||
f"{DEVICE_NAME.replace(' ', '_')}_masked_scatter_{dtype}.pdf",
|
||||
)
|
||||
fig.savefig(output_path)
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user