Fix mypy errors with models/{qwen2,qwen2_moe,starcoder2}.py (#835)

* Fix starcoder2.py

* Fix qwen2

* Remove unnecessary assert not None
Yi Wang 2024-06-14 09:44:50 -07:00 committed by GitHub
parent d8b073e3a7
commit a7598e9456
3 changed files with 13 additions and 11 deletions
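
The ModelArgs change targets mypy's no-implicit-optional rule: a field annotated as int may not default to None, so the annotation becomes Optional[int], and the new assert in Attention.__init__ narrows the value back to int before it is used. A minimal standalone sketch of the pattern (the names below are illustrative, not the actual module):

from dataclasses import dataclass
from typing import Optional


@dataclass
class ExampleArgs:
    num_attention_heads: int
    # "num_key_value_heads: int = None" fails mypy (implicit Optional);
    # the explicit Optional[int] annotation is accepted.
    num_key_value_heads: Optional[int] = None


def kv_heads(args: ExampleArgs) -> int:
    # The assert narrows Optional[int] to int for the type checker,
    # mirroring the assert added in Attention.__init__.
    assert args.num_key_value_heads is not None
    return args.num_key_value_heads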

models/qwen2.py

@@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
 @dataclass
@@ -16,7 +16,7 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 1000000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
@@ -41,6 +41,7 @@ class Attention(nn.Module):
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
         head_dim = args.hidden_size // n_heads
@@ -67,7 +68,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
@@ -121,7 +122,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r

models/qwen2_moe.py

@@ -5,7 +5,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
 from .switch_layers import SwitchGLU
@@ -22,7 +22,7 @@ class ModelArgs(BaseModelArgs):
     shared_expert_intermediate_size: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 1000000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
@@ -47,6 +47,7 @@ class Attention(nn.Module):
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
         head_dim = args.hidden_size // n_heads
@@ -67,7 +68,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
@@ -159,7 +160,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r

models/starcoder2.py

@@ -4,7 +4,7 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
 @dataclass
@@ -43,7 +43,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
@@ -98,7 +98,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
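
The remaining change swaps the cache annotation on the attention and block __call__ methods from a raw (keys, values) tuple to the KVCache class imported from .base, so the signature presumably matches the cache object that the rest of the package passes in and mypy can check attribute access on it. A rough standalone sketch of why the class-typed parameter helps; SimpleKVCache below is a hypothetical stand-in, not the real class in .base:

from typing import Optional, Tuple

import mlx.core as mx


class SimpleKVCache:
    # Hypothetical stand-in for the cache class the models import from .base.
    def __init__(self) -> None:
        self.keys: Optional[mx.array] = None
        self.values: Optional[mx.array] = None

    def update_and_fetch(
        self, keys: mx.array, values: mx.array
    ) -> Tuple[mx.array, mx.array]:
        # Concatenate the new keys/values onto the stored history
        # (axis 2 is the sequence axis for (B, n_heads, L, head_dim)).
        if self.keys is not None and self.values is not None:
            keys = mx.concatenate([self.keys, keys], axis=2)
            values = mx.concatenate([self.values, values], axis=2)
        self.keys, self.values = keys, values
        return keys, values


def extend_cache(
    keys: mx.array, values: mx.array, cache: Optional[SimpleKVCache] = None
) -> Tuple[mx.array, mx.array]:
    # With a class-typed cache, mypy can verify this method call; with the
    # old Optional[Tuple[mx.array, mx.array]] annotation it could not.
    if cache is not None:
        keys, values = cache.update_and_fetch(keys, values)
    return keys, values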