mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Fix type for attention mask
This commit is contained in:
parent
5ae339f6d2
commit
66e1c0f050
2
t5/t5.py
2
t5/t5.py
@ -104,7 +104,7 @@ class MultiHeadAttention(nn.Module):
|
|||||||
queries: mx.array,
|
queries: mx.array,
|
||||||
keys: mx.array,
|
keys: mx.array,
|
||||||
values: mx.array,
|
values: mx.array,
|
||||||
mask: mx.array,
|
mask: Optional[mx.array],
|
||||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||||
) -> [mx.array, Tuple[mx.array, mx.array]]:
|
) -> [mx.array, Tuple[mx.array, mx.array]]:
|
||||||
queries = self.query_proj(queries)
|
queries = self.query_proj(queries)
|
||||||
|
Loading…
Reference in New Issue
Block a user