Fix type for attention mask

This commit is contained in:
Juarez Bochi 2023-12-18 11:39:17 -05:00
parent 5ae339f6d2
commit 66e1c0f050
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6

View File

@ -104,7 +104,7 @@ class MultiHeadAttention(nn.Module):
queries: mx.array,
keys: mx.array,
values: mx.array,
mask: mx.array,
mask: Optional[mx.array],
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> [mx.array, Tuple[mx.array, mx.array]]:
queries = self.query_proj(queries)