From ac6dc5d3ebe230c2f888215a9346c9deaaf302cd Mon Sep 17 00:00:00 2001 From: Joe Barrow Date: Sat, 9 Dec 2023 14:04:28 -0500 Subject: [PATCH] Adding optional bias param to MultiHeadAttention (#104) * Adding optional param to * Run style-checker --- python/mlx/nn/layers/transformer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py index 7fffe4291..e24957066 100644 --- a/python/mlx/nn/layers/transformer.py +++ b/python/mlx/nn/layers/transformer.py @@ -43,6 +43,7 @@ class MultiHeadAttention(Module): value_input_dims: Optional[int] = None, value_dims: Optional[int] = None, value_output_dims: Optional[int] = None, + bias: bool = False, ): super().__init__() @@ -58,10 +59,10 @@ class MultiHeadAttention(Module): value_output_dims = value_output_dims or dims self.num_heads = num_heads - self.query_proj = Linear(query_input_dims, dims, False) - self.key_proj = Linear(key_input_dims, dims, False) - self.value_proj = Linear(value_input_dims, value_dims, False) - self.out_proj = Linear(value_dims, value_output_dims, False) + self.query_proj = Linear(query_input_dims, dims, bias=bias) + self.key_proj = Linear(key_input_dims, dims, bias=bias) + self.value_proj = Linear(value_input_dims, value_dims, bias=bias) + self.out_proj = Linear(value_dims, value_output_dims, bias=bias) def __call__(self, queries, keys, values, mask=None): queries = self.query_proj(queries)