Update headdim 128 tuning

This commit is contained in:
Jagrit Digani
2024-11-20 15:41:34 -08:00
parent 791f50d9f3
commit d571366250
4 changed files with 13 additions and 6 deletions

View File

@@ -144,7 +144,7 @@ if __name__ == "__main__":
transposes = (False,)
# fmt: off
shapes = (
shapes_64 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 32, 32, 64, 32, 32),
( 1, 64, 64, 64, 32, 32),
@@ -162,9 +162,16 @@ if __name__ == "__main__":
( 1, 2048, 2048, 80, 32, 32),
( 1, 4096, 4096, 80, 32, 32),
)
shapes_128 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 128, 32, 32),
( 1, 2048, 2048, 128, 32, 32),
( 1, 4096, 4096, 128, 32, 32),
)
# fmt: on
shapes = shapes + shapes_80
shapes = shapes_64 + shapes_80 + shapes_128
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")