From 0618a400a257dd7d550ff7896992655e38c706e6 Mon Sep 17 00:00:00 2001
From: Sushant
Date: Tue, 26 Dec 2023 23:00:01 +0530
Subject: [PATCH] Fixed the return type for the __call__ method in Attention

---
 llms/llama/llama.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llms/llama/llama.py b/llms/llama/llama.py
index 97ec4101..1b44d650 100644
--- a/llms/llama/llama.py
+++ b/llms/llama/llama.py
@@ -67,7 +67,7 @@ class Attention(nn.Module):
         x: mx.array,
         mask: Optional[mx.array] = None,
         cache: Optional[Tuple[mx.array, mx.array]] = None,
-    ) -> mx.array:
+    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
         B, L, D = x.shape

         queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
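
Context for the annotation change (not part of the applied diff): the __call__ method returns both the attention output and the updated key/value cache, so the corrected annotation is a pair rather than a bare mx.array. A minimal caller-side sketch, assuming the Attention module and a ModelArgs-style `args` object from llms/llama/llama.py are in scope; the shapes and `args.dim` attribute are assumptions for illustration only:

    import mlx.core as mx

    attn = Attention(args)            # `args` is a hypothetical ModelArgs instance
    cache = None                      # becomes a (keys, values) tuple after the first call
    x = mx.zeros((1, 8, args.dim))    # (batch, sequence, hidden) input; shapes assumed
    y, cache = attn(x, mask=None, cache=cache)
    # y: mx.array, cache: Tuple[mx.array, mx.array] -- matching the new return annotation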