mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Fixed the return type for the __call__ method in Attention (#190)
This commit is contained in:
parent
2bd20ef0e0
commit
a516f4635d
@ -67,7 +67,7 @@ class Attention(nn.Module):
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
|
||||
|
Loading…
Reference in New Issue
Block a user