Update sdpa_benchmarks

This commit is contained in:
Jagrit Digani 2024-11-20 15:19:42 -08:00
parent f1d87a2d3e
commit c9ab537b9a

View File

@ -145,6 +145,8 @@ if __name__ == "__main__":
"float32",
)
transposes = (False,)
# clang-format off
shapes = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
(1, 32, 32, 64, 32, 32),
@ -156,6 +158,7 @@ if __name__ == "__main__":
(1, 2048, 2048, 64, 32, 32),
(1, 4096, 4096, 64, 32, 32),
)
# clang-format on
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")