Fix incorrect type annotation (#720)

A `Tuple` is missing in this type annotation.
This commit is contained in:
Kevin Wang 2024-04-24 18:52:43 -04:00 committed by GitHub
parent abcd891851
commit 8a265f0d54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -105,7 +105,7 @@ class MultiHeadAttention(nn.Module):
values: mx.array, values: mx.array,
mask: Optional[mx.array], mask: Optional[mx.array],
cache: Optional[Tuple[mx.array, mx.array]] = None, cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> [mx.array, Tuple[mx.array, mx.array]]: ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
queries = self.query_proj(queries) queries = self.query_proj(queries)
keys = self.key_proj(keys) keys = self.key_proj(keys)
values = self.value_proj(values) values = self.value_proj(values)