Fixed the return type for the __call__ method in Attention (#190)

This commit is contained in:
Sushant
2023-12-26 23:02:43 +05:30
committed by GitHub
parent c63517bbb3
commit 89757a7f8a

View File

@@ -67,7 +67,7 @@ class Attention(nn.Module):
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
B, L, D = x.shape
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)