Faster complex matmul (#2571)

This commit is contained in:
Daniel Yeh
2025-10-03 08:33:15 +02:00
committed by GitHub
parent 287c63a093
commit 22a5da76c8
20 changed files with 623 additions and 73 deletions

View File

@@ -142,9 +142,7 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)
c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
np.float32
)
c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype)
atol = 1e-5 if np_dtype == np.float32 else 1e-4
@@ -163,7 +161,7 @@ def get_gflop_count(B, M, N, K):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
dtypes = ("float32", "float16")
dtypes = ("float32", "float16", "complex64")
transposes = ("nn", "nt", "tn")
shapes = (
(16, 234, 768, 3072),
@@ -187,7 +185,7 @@ if __name__ == "__main__":
diff = gflops_mx / gflops_pt - 1.0
print(
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%"
)
if gflops_pt >= 2.0 * gflops_mx:
print("ATTENTION ^^^^^^^")

View File

@@ -196,7 +196,7 @@ def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
for transpose in (False, True):
for dtype in ("float32", "float16"):
for dtype in ("float32", "float16", "complex64"):
fig, axs = plt.subplots(
len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
)
@@ -215,7 +215,7 @@ for transpose in (False, True):
fig.suptitle(f"{device_name}: {dtype} {op_name}")
fig.savefig(
os.path.join(
results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
results_dir, f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf"
)
)
plt.close(fig)