Fix mypy errors with models/{qwen2,qwen2_moe,starcoder2}.py (#835)

* Fix starcoder2.py

* Fix qwen2

* Remove unnecessary assert not None
Yi Wang 2024-06-14 09:44:50 -07:00 committed by GitHub
parent d8b073e3a7
commit a7598e9456
3 changed files with 13 additions and 11 deletions
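
Two of the recurring changes belong together: a dataclass field that defaults to None has to be annotated Optional[int], and an assert in the consuming code then narrows it back to int. Below is a minimal, self-contained sketch of that pattern; the Args class and the helper function are illustrative stand-ins, not the actual ModelArgs or Attention code.

from dataclasses import dataclass
from typing import Optional


@dataclass
class Args:
    num_attention_heads: int
    # Defaulting a plain `int` field to None is a mypy error; the None
    # default needs an explicit Optional[int] annotation.
    num_key_value_heads: Optional[int] = None


def kv_repeat_factor(args: Args) -> int:
    # The assert narrows Optional[int] to int so the division type-checks,
    # and it fails fast if the config never filled the field in.
    assert args.num_key_value_heads is not None
    return args.num_attention_heads // args.num_key_value_heads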

models/qwen2.py

@@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
 @dataclass
@@ -16,7 +16,7 @@ class ModelArgs(BaseModelArgs):
 num_attention_heads: int
 rms_norm_eps: float
 vocab_size: int
-num_key_value_heads: int = None
+num_key_value_heads: Optional[int] = None
 rope_theta: float = 1000000
 rope_traditional: bool = False
 rope_scaling: Optional[Dict[str, Union[float, str]]] = None
@@ -41,6 +41,7 @@ class Attention(nn.Module):
 dim = args.hidden_size
 self.n_heads = n_heads = args.num_attention_heads
+assert args.num_key_value_heads is not None
 self.n_kv_heads = n_kv_heads = args.num_key_value_heads
 head_dim = args.hidden_size // n_heads
@@ -67,7 +68,7 @@ class Attention(nn.Module):
 self,
 x: mx.array,
 mask: Optional[mx.array] = None,
-cache: Optional[Tuple[mx.array, mx.array]] = None,
+cache: Optional[KVCache] = None,
 ) -> mx.array:
 B, L, D = x.shape
@@ -121,7 +122,7 @@ class TransformerBlock(nn.Module):
 self,
 x: mx.array,
 mask: Optional[mx.array] = None,
-cache: Optional[Tuple[mx.array, mx.array]] = None,
+cache: Optional[KVCache] = None,
 ) -> mx.array:
 r = self.self_attn(self.input_layernorm(x), mask, cache)
 h = x + r

models/qwen2_moe.py

@@ -5,7 +5,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
 from .switch_layers import SwitchGLU
@@ -22,7 +22,7 @@ class ModelArgs(BaseModelArgs):
 shared_expert_intermediate_size: int
 rms_norm_eps: float
 vocab_size: int
-num_key_value_heads: int = None
+num_key_value_heads: Optional[int] = None
 rope_theta: float = 1000000
 rope_traditional: bool = False
 rope_scaling: Optional[Dict[str, Union[float, str]]] = None
@@ -47,6 +47,7 @@ class Attention(nn.Module):
 dim = args.hidden_size
 self.n_heads = n_heads = args.num_attention_heads
+assert args.num_key_value_heads is not None
 self.n_kv_heads = n_kv_heads = args.num_key_value_heads
 head_dim = args.hidden_size // n_heads
@@ -67,7 +68,7 @@ class Attention(nn.Module):
 self,
 x: mx.array,
 mask: Optional[mx.array] = None,
-cache: Optional[Tuple[mx.array, mx.array]] = None,
+cache: Optional[KVCache] = None,
 ) -> mx.array:
 B, L, D = x.shape
@@ -159,7 +160,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
 self,
 x: mx.array,
 mask: Optional[mx.array] = None,
-cache: Optional[Tuple[mx.array, mx.array]] = None,
+cache: Optional[KVCache] = None,
 ) -> mx.array:
 r = self.self_attn(self.input_layernorm(x), mask, cache)
 h = x + r

models/starcoder2.py

@@ -4,7 +4,7 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
 @dataclass
@@ -43,7 +43,7 @@ class Attention(nn.Module):
 self,
 x: mx.array,
 mask: Optional[mx.array] = None,
-cache: Optional[Tuple[mx.array, mx.array]] = None,
+cache: Optional[KVCache] = None,
 ) -> mx.array:
 B, L, D = x.shape
@@ -98,7 +98,7 @@ class TransformerBlock(nn.Module):
 self,
 x: mx.array,
 mask: Optional[mx.array] = None,
-cache: Optional[Tuple[mx.array, mx.array]] = None,
+cache: Optional[KVCache] = None,
 ) -> mx.array:
 r = self.self_attn(self.input_layernorm(x), mask, cache)
 h = x + r
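
For completeness, one way to re-run the check these fixes target is mypy's Python API; the file paths below follow the commit title and may need adjusting to the repository layout.

from mypy import api

# api.run returns (report, errors, exit_status); exit_status 0 means the
# listed files type-check cleanly.
report, errors, exit_status = api.run(
    ["models/qwen2.py", "models/qwen2_moe.py", "models/starcoder2.py"]
)
print(report or errors)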