mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Adding optional bias param to MultiHeadAttention (#104)
* Adding optional param to MultiHeadAttention
* Run style-checker
This commit is contained in:
parent
89b90dcfec
commit
ac6dc5d3eb
@ -43,6 +43,7 @@ class MultiHeadAttention(Module):
|
|||||||
value_input_dims: Optional[int] = None,
|
value_input_dims: Optional[int] = None,
|
||||||
value_dims: Optional[int] = None,
|
value_dims: Optional[int] = None,
|
||||||
value_output_dims: Optional[int] = None,
|
value_output_dims: Optional[int] = None,
|
||||||
|
bias: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -58,10 +59,10 @@ class MultiHeadAttention(Module):
|
|||||||
value_output_dims = value_output_dims or dims
|
value_output_dims = value_output_dims or dims
|
||||||
|
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.query_proj = Linear(query_input_dims, dims, False)
|
self.query_proj = Linear(query_input_dims, dims, bias=bias)
|
||||||
self.key_proj = Linear(key_input_dims, dims, False)
|
self.key_proj = Linear(key_input_dims, dims, bias=bias)
|
||||||
self.value_proj = Linear(value_input_dims, value_dims, False)
|
self.value_proj = Linear(value_input_dims, value_dims, bias=bias)
|
||||||
self.out_proj = Linear(value_dims, value_output_dims, False)
|
self.out_proj = Linear(value_dims, value_output_dims, bias=bias)
|
||||||
|
|
||||||
def __call__(self, queries, keys, values, mask=None):
|
def __call__(self, queries, keys, values, mask=None):
|
||||||
queries = self.query_proj(queries)
|
queries = self.query_proj(queries)
|
||||||
|
Loading…
Reference in New Issue
Block a user