diff --git a/llms/speculative_decoding/model.py b/llms/speculative_decoding/model.py index 2300fb6c..c310b943 100644 --- a/llms/speculative_decoding/model.py +++ b/llms/speculative_decoding/model.py @@ -105,7 +105,7 @@ class MultiHeadAttention(nn.Module): values: mx.array, mask: Optional[mx.array], cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> [mx.array, Tuple[mx.array, mx.array]]: + ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]: queries = self.query_proj(queries) keys = self.key_proj(keys) values = self.value_proj(values)