Correct the type annotation of cache in llama.py (#828)

* Update

* Fix isort
Yi Wang authored 2024-06-10 15:18:34 -07:00 (committed by GitHub)
parent bb8227f181
commit a54dfd698e


@@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, create_additive_causal_mask
+from .base import BaseModelArgs, KVCache, create_additive_causal_mask
 
 
 @dataclass
@@ -73,7 +73,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
 
@@ -135,7 +135,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r