mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-11-04 05:28:11 +08:00)
Fix mypy errors with models/{qwen2,qwen2_moe,starcoder2}.py (#835)

* Fix starcoder.py
* Fix qwen2
* Remove unnecessary assert not None
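The pattern behind these fixes, as a minimal sketch (the class and function names below are illustrative, not taken from the repository): mypy rejects a field annotated as int but defaulted to None, so the diff declares it Optional[int] and adds an assert-not-None in Attention.__init__ to narrow the type before it is used as an int.

# Minimal sketch of the mypy fix pattern applied in the hunks below
# (names here are illustrative, not taken from the repository).
from dataclasses import dataclass
from typing import Optional


@dataclass
class ArgsSketch:
    num_attention_heads: int
    # Before the fix: `num_key_value_heads: int = None`, which mypy rejects
    # because None is not an int.
    num_key_value_heads: Optional[int] = None


def kv_heads(args: ArgsSketch) -> int:
    # The added assert narrows Optional[int] to int, so the return type checks.
    assert args.num_key_value_heads is not None
    return args.num_key_value_heads


print(kv_heads(ArgsSketch(num_attention_heads=32, num_key_value_heads=8)))  # -> 8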
models/qwen2.py
@@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
 
 
 @dataclass
@@ -16,7 +16,7 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 1000000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
@@ -41,6 +41,7 @@ class Attention(nn.Module):
 
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
 
         head_dim = args.hidden_size // n_heads
@@ -67,7 +68,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
 
@@ -121,7 +122,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
models/qwen2_moe.py
@@ -5,7 +5,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
 from .switch_layers import SwitchGLU
 
 
@@ -22,7 +22,7 @@ class ModelArgs(BaseModelArgs):
     shared_expert_intermediate_size: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 1000000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
@@ -47,6 +47,7 @@ class Attention(nn.Module):
 
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
 
         head_dim = args.hidden_size // n_heads
@@ -67,7 +68,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
 
@@ -159,7 +160,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
models/starcoder2.py
@@ -4,7 +4,7 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
 
 
 @dataclass
@@ -43,7 +43,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
 
@@ -98,7 +98,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
    ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
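For the cache annotation change, a hedged sketch of why Optional[KVCache] is the more accurate type: the diff imports KVCache from .base, and the forward methods receive that object rather than a raw (keys, values) tuple. The stub class below is a placeholder of my own, not the real KVCache API from models/base.py.

# Illustrative only: KVCacheStub stands in for the KVCache imported from
# .base in this diff; the real class's API is not shown here.
from typing import Optional

import mlx.core as mx


class KVCacheStub:
    """Placeholder cache object used only to demonstrate the annotation."""


def forward(
    x: mx.array,
    mask: Optional[mx.array] = None,
    cache: Optional[KVCacheStub] = None,  # was Optional[Tuple[mx.array, mx.array]]
) -> mx.array:
    # Annotating the parameter with the cache object's actual class lets mypy
    # check call sites instead of reporting a tuple/object mismatch.
    return x


print(forward(mx.zeros((1, 1, 8)), cache=KVCacheStub()).shape)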