mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Fixed the return type for the __call__ method in Attention (#190)
This commit is contained in:
parent
2bd20ef0e0
commit
a516f4635d
@ -67,7 +67,7 @@ class Attention(nn.Module):
|
|||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||||
) -> mx.array:
|
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
|
||||||
B, L, D = x.shape
|
B, L, D = x.shape
|
||||||
|
|
||||||
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
|
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
|
||||||
|
Loading…
Reference in New Issue
Block a user