From a516f4635d048ad3c110dd4af37b7bb11faa38dd Mon Sep 17 00:00:00 2001
From: Sushant
Date: Tue, 26 Dec 2023 23:02:43 +0530
Subject: [PATCH] Fixed the return type for the __call__ method in Attention
 (#190)

---
 llms/llama/llama.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llms/llama/llama.py b/llms/llama/llama.py
index 97ec4101..1b44d650 100644
--- a/llms/llama/llama.py
+++ b/llms/llama/llama.py
@@ -67,7 +67,7 @@ class Attention(nn.Module):
         x: mx.array,
         mask: Optional[mx.array] = None,
         cache: Optional[Tuple[mx.array, mx.array]] = None,
-    ) -> mx.array:
+    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
         B, L, D = x.shape

         queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
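
Note (context for this change): Attention.__call__ computes the attention output and also returns the updated key/value cache used for incremental decoding, so the annotated return type is a 2-tuple rather than a single mx.array. Below is a minimal, self-contained sketch of such a layer written against MLX; the constructor arguments, head-splitting details, and helper names are illustrative assumptions, not the exact code from the patched file.

    # Minimal sketch (assumed shapes and constructor), illustrating why the
    # method returns (output, (keys, values)) rather than a single array.
    from typing import Optional, Tuple

    import mlx.core as mx
    import mlx.nn as nn


    class Attention(nn.Module):
        def __init__(self, dims: int, num_heads: int):
            super().__init__()
            self.num_heads = num_heads
            self.scale = (dims // num_heads) ** -0.5
            self.wq = nn.Linear(dims, dims, bias=False)
            self.wk = nn.Linear(dims, dims, bias=False)
            self.wv = nn.Linear(dims, dims, bias=False)
            self.wo = nn.Linear(dims, dims, bias=False)

        def __call__(
            self,
            x: mx.array,
            mask: Optional[mx.array] = None,
            cache: Optional[Tuple[mx.array, mx.array]] = None,
        ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
            B, L, D = x.shape
            queries, keys, values = self.wq(x), self.wk(x), self.wv(x)

            # Split heads: (B, num_heads, L, head_dim)
            queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
            keys = keys.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
            values = values.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)

            # During incremental decoding, append to the existing KV cache.
            if cache is not None:
                key_cache, value_cache = cache
                keys = mx.concatenate([key_cache, keys], axis=2)
                values = mx.concatenate([value_cache, values], axis=2)

            scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
            if mask is not None:
                scores = scores + mask
            scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
            out = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)

            # The second element is the updated cache, hence the tuple return type.
            return self.wo(out), (keys, values)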