Adding optional bias param to MultiHeadAttention (#104)

* Adding optional `bias` param to `MultiHeadAttention`

* Run style-checker
This commit is contained in:
Joe Barrow 2023-12-09 14:04:28 -05:00 committed by GitHub
parent 89b90dcfec
commit ac6dc5d3eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -43,6 +43,7 @@ class MultiHeadAttention(Module):
value_input_dims: Optional[int] = None,
value_dims: Optional[int] = None,
value_output_dims: Optional[int] = None,
bias: bool = False,
):
super().__init__()
@ -58,10 +59,10 @@ class MultiHeadAttention(Module):
value_output_dims = value_output_dims or dims
self.num_heads = num_heads
self.query_proj = Linear(query_input_dims, dims, False)
self.key_proj = Linear(key_input_dims, dims, False)
self.value_proj = Linear(value_input_dims, value_dims, False)
self.out_proj = Linear(value_dims, value_output_dims, False)
self.query_proj = Linear(query_input_dims, dims, bias=bias)
self.key_proj = Linear(key_input_dims, dims, bias=bias)
self.value_proj = Linear(value_input_dims, value_dims, bias=bias)
self.out_proj = Linear(value_dims, value_output_dims, bias=bias)
def __call__(self, queries, keys, values, mask=None):
queries = self.query_proj(queries)