mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-11 10:01:14 +08:00
Update sdpa_benchmarks
This commit is contained in:
parent
c9ab537b9a
commit
0c22440c75
@ -140,25 +140,22 @@ def get_gflop_count(B, M, N, K):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
|
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
|
||||||
|
|
||||||
dtypes = (
|
dtypes = ("float16", "float32")
|
||||||
"float16",
|
|
||||||
"float32",
|
|
||||||
)
|
|
||||||
transposes = (False,)
|
transposes = (False,)
|
||||||
|
|
||||||
# clang-format off
|
# fmt: off
|
||||||
shapes = (
|
shapes = (
|
||||||
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||||
(1, 32, 32, 64, 32, 32),
|
( 1, 32, 32, 64, 32, 32),
|
||||||
(1, 64, 64, 64, 32, 32),
|
( 1, 64, 64, 64, 32, 32),
|
||||||
(1, 128, 128, 64, 32, 32),
|
( 1, 128, 128, 64, 32, 32),
|
||||||
(1, 256, 256, 64, 32, 32),
|
( 1, 256, 256, 64, 32, 32),
|
||||||
(1, 512, 512, 64, 32, 32),
|
( 1, 512, 512, 64, 32, 32),
|
||||||
(1, 1024, 1024, 64, 32, 32),
|
( 1, 1024, 1024, 64, 32, 32),
|
||||||
(1, 2048, 2048, 64, 32, 32),
|
( 1, 2048, 2048, 64, 32, 32),
|
||||||
(1, 4096, 4096, 64, 32, 32),
|
( 1, 4096, 4096, 64, 32, 32),
|
||||||
)
|
)
|
||||||
# clang-format on
|
# fmt: on
|
||||||
|
|
||||||
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
|
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user