mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-05 00:31:15 +08:00
Update sdpa_benchmarks
This commit is contained in:
parent
f1d87a2d3e
commit
c9ab537b9a
@ -145,6 +145,8 @@ if __name__ == "__main__":
|
||||
"float32",
|
||||
)
|
||||
transposes = (False,)
|
||||
|
||||
# clang-format off
|
||||
shapes = (
|
||||
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||
(1, 32, 32, 64, 32, 32),
|
||||
@ -156,6 +158,7 @@ if __name__ == "__main__":
|
||||
(1, 2048, 2048, 64, 32, 32),
|
||||
(1, 4096, 4096, 64, 32, 32),
|
||||
)
|
||||
# clang-format on
|
||||
|
||||
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user