diff --git a/llms/llama/llama.py b/llms/llama/llama.py index 97ec4101..1b44d650 100644 --- a/llms/llama/llama.py +++ b/llms/llama/llama.py @@ -67,7 +67,7 @@ class Attention(nn.Module): x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> mx.array: + ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]: B, L, D = x.shape queries, keys, values = self.wq(x), self.wk(x), self.wv(x)