diff --git a/benchmarks/python/sdpa_bench.py b/benchmarks/python/sdpa_bench.py
index 56042e2d2..c26c9ce04 100644
--- a/benchmarks/python/sdpa_bench.py
+++ b/benchmarks/python/sdpa_bench.py
@@ -140,23 +140,32 @@ def get_gflop_count(B, M, N, K):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run gemm benchmarks")
 
-    dtypes = ("float16", "float32")
+    dtypes = ("float16", "float32")[:1]
     transposes = (False,)
 
     # fmt: off
     shapes = (
         # (  B,  qsl,  ksl, head_dim, n_qh, n_kvh)
-        ( 1,    32,    32,    64,    32,    32),
-        ( 1,    64,    64,    64,    32,    32),
-        ( 1,   128,   128,    64,    32,    32),
-        ( 1,   256,   256,    64,    32,    32),
-        ( 1,   512,   512,    64,    32,    32),
-        ( 1,  1024,  1024,    64,    32,    32),
-        ( 1,  2048,  2048,    64,    32,    32),
-        ( 1,  4096,  4096,    64,    32,    32),
+        (  1,   32,   32,   64,   32,   32),
+        (  1,   64,   64,   64,   32,   32),
+        (  1,  128,  128,   64,   32,   32),
+        (  1,  256,  256,   64,   32,   32),
+        (  1,  512,  512,   64,   32,   32),
+        (  1, 1024, 1024,   64,   32,   32),
+        (  1, 2048, 2048,   64,   32,   32),
+        (  1, 4096, 4096,   64,   32,   32),
+    )
+
+    shapes_80 = (
+        # (  B,  qsl,  ksl, head_dim, n_qh, n_kvh)
+        (  1, 1024, 1024,   80,   32,   32),
+        (  1, 2048, 2048,   80,   32,   32),
+        (  1, 4096, 4096,   80,   32,   32),
     )
     # fmt: on
 
+    shapes = shapes + shapes_80
+
     print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
 
     for dtype in dtypes:
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 8cbf0d015..47f17745c 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -644,7 +644,7 @@ array scaled_dot_product_attention(
   const bool sdpa_vector_supported_head_dim =
       query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128;
   const bool sdpa_full_supported_head_dim =
-      query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128;
+      query_head_dim == 64 || query_head_dim == 80;
   const bool supports_sdpa_full = query_sequence_length >= threshold &&
       !mask.has_value() && sdpa_full_supported_head_dim &&
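
For reference, a minimal sketch (not part of the diff) of how the gated path can be exercised from Python. It assumes the public mx.fast.scaled_dot_product_attention binding for the C++ scaled_dot_product_attention above, and mirrors one of the new shapes_80 benchmark entries; the variable names are just the benchmark's column labels.

    import mlx.core as mx

    # Shape taken from the new shapes_80 table: head_dim 80, long sequences.
    B, qsl, ksl, head_dim, n_qh, n_kvh = 1, 1024, 1024, 80, 32, 32

    q = mx.random.normal((B, n_qh, qsl, head_dim), dtype=mx.float16)
    k = mx.random.normal((B, n_kvh, ksl, head_dim), dtype=mx.float16)
    v = mx.random.normal((B, n_kvh, ksl, head_dim), dtype=mx.float16)

    # No mask and a long query sequence, so per the context lines above
    # supports_sdpa_full can hold, and head_dim == 80 stays on the fused
    # full-attention path while head_dim == 128 no longer qualifies.
    out = mx.fast.scaled_dot_product_attention(
        q, k, v, scale=head_dim**-0.5, mask=None
    )
    mx.eval(out)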