mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Adding optional bias param to MultiHeadAttention (#104)
* Adding optional param to * Run style-checker
This commit is contained in:
parent
89b90dcfec
commit
ac6dc5d3eb
@ -43,6 +43,7 @@ class MultiHeadAttention(Module):
|
||||
value_input_dims: Optional[int] = None,
|
||||
value_dims: Optional[int] = None,
|
||||
value_output_dims: Optional[int] = None,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -58,10 +59,10 @@ class MultiHeadAttention(Module):
|
||||
value_output_dims = value_output_dims or dims
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.query_proj = Linear(query_input_dims, dims, False)
|
||||
self.key_proj = Linear(key_input_dims, dims, False)
|
||||
self.value_proj = Linear(value_input_dims, value_dims, False)
|
||||
self.out_proj = Linear(value_dims, value_output_dims, False)
|
||||
self.query_proj = Linear(query_input_dims, dims, bias=bias)
|
||||
self.key_proj = Linear(key_input_dims, dims, bias=bias)
|
||||
self.value_proj = Linear(value_input_dims, value_dims, bias=bias)
|
||||
self.out_proj = Linear(value_dims, value_output_dims, bias=bias)
|
||||
|
||||
def __call__(self, queries, keys, values, mask=None):
|
||||
queries = self.query_proj(queries)
|
||||
|
Loading…
Reference in New Issue
Block a user