mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Fixed the return type for the __call__ method in Attention
This commit is contained in:
@@ -67,7 +67,7 @@ class Attention(nn.Module):
|
|||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||||
) -> mx.array:
|
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
|
||||||
B, L, D = x.shape
|
B, L, D = x.shape
|
||||||
|
|
||||||
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
|
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
|
||||||
|
Reference in New Issue
Block a user