From 66e1c0f050f090d5580f131d40ba2ef190dac617 Mon Sep 17 00:00:00 2001 From: Juarez Bochi Date: Mon, 18 Dec 2023 11:39:17 -0500 Subject: [PATCH] Fix type for attention mask --- t5/t5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/t5/t5.py b/t5/t5.py index 65c758e4..6a21c791 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -104,7 +104,7 @@ class MultiHeadAttention(nn.Module): queries: mx.array, keys: mx.array, values: mx.array, - mask: mx.array, + mask: Optional[mx.array], cache: Optional[Tuple[mx.array, mx.array]] = None, ) -> [mx.array, Tuple[mx.array, mx.array]]: queries = self.query_proj(queries)