diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py
index 7fffe4291..e24957066 100644
--- a/python/mlx/nn/layers/transformer.py
+++ b/python/mlx/nn/layers/transformer.py
@@ -43,6 +43,7 @@ class MultiHeadAttention(Module):
         value_input_dims: Optional[int] = None,
         value_dims: Optional[int] = None,
         value_output_dims: Optional[int] = None,
+        bias: bool = False,
     ):
         super().__init__()
 
@@ -58,10 +59,10 @@ class MultiHeadAttention(Module):
         value_output_dims = value_output_dims or dims
 
         self.num_heads = num_heads
-        self.query_proj = Linear(query_input_dims, dims, False)
-        self.key_proj = Linear(key_input_dims, dims, False)
-        self.value_proj = Linear(value_input_dims, value_dims, False)
-        self.out_proj = Linear(value_dims, value_output_dims, False)
+        self.query_proj = Linear(query_input_dims, dims, bias=bias)
+        self.key_proj = Linear(key_input_dims, dims, bias=bias)
+        self.value_proj = Linear(value_input_dims, value_dims, bias=bias)
+        self.out_proj = Linear(value_dims, value_output_dims, bias=bias)
 
     def __call__(self, queries, keys, values, mask=None):
         queries = self.query_proj(queries)