Fixed the return type for the __call__ method in Attention (#190)

This commit is contained in:
Sushant 2023-12-26 23:02:43 +05:30 committed by GitHub
parent 2bd20ef0e0
commit a516f4635d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -67,7 +67,7 @@ class Attention(nn.Module):
x: mx.array, x: mx.array,
mask: Optional[mx.array] = None, mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None, cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array: ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
B, L, D = x.shape B, L, D = x.shape
queries, keys, values = self.wq(x), self.wk(x), self.wv(x) queries, keys, values = self.wq(x), self.wk(x), self.wv(x)