From a7598e9456c6455a07ff4905712c2ea3cfcd52db Mon Sep 17 00:00:00 2001
From: Yi Wang
Date: Fri, 14 Jun 2024 09:44:50 -0700
Subject: [PATCH] Fix mypy errors with models/{qwen2,qwen2_moe,starcoder2}.py (#835)

* Fix starcoder2.py

* Fix qwen2

* Remove unnecessary assert not None
---
 llms/mlx_lm/models/qwen2.py      | 9 +++++----
 llms/mlx_lm/models/qwen2_moe.py  | 9 +++++----
 llms/mlx_lm/models/starcoder2.py | 6 +++---
 3 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py
index b928de09..fab09003 100644
--- a/llms/mlx_lm/models/qwen2.py
+++ b/llms/mlx_lm/models/qwen2.py
@@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
 
 
 @dataclass
@@ -16,7 +16,7 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 1000000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
@@ -41,6 +41,7 @@ class Attention(nn.Module):
 
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
 
         head_dim = args.hidden_size // n_heads
@@ -67,7 +68,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
 
@@ -121,7 +122,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py
index bba02da0..57f154a0 100644
--- a/llms/mlx_lm/models/qwen2_moe.py
+++ b/llms/mlx_lm/models/qwen2_moe.py
@@ -5,7 +5,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
 from .switch_layers import SwitchGLU
 
 @dataclass
@@ -22,7 +22,7 @@ class ModelArgs(BaseModelArgs):
     shared_expert_intermediate_size: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 1000000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
@@ -47,6 +47,7 @@ class Attention(nn.Module):
 
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
 
         head_dim = args.hidden_size // n_heads
@@ -67,7 +68,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
 
@@ -159,7 +160,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py
index ca06bdb1..7b058d8f 100644
--- a/llms/mlx_lm/models/starcoder2.py
+++ b/llms/mlx_lm/models/starcoder2.py
@@ -4,7 +4,7 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
 
 
 @dataclass
@@ -43,7 +43,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
 
@@ -98,7 +98,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
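
Note: the pattern the patch applies is to declare a field that defaults to None as Optional[...] and then narrow it with an assert before use. The following is a minimal standalone sketch of that pattern, not code from mlx_lm; the Args class and n_kv_heads helper are illustrative names only.

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Args:
        num_attention_heads: int
        # `int = None` trips mypy's implicit-Optional check; Optional[int]
        # makes the None default explicit.
        num_key_value_heads: Optional[int] = None

    def n_kv_heads(args: Args) -> int:
        # The assert narrows Optional[int] to int, so returning it type-checks.
        assert args.num_key_value_heads is not None
        return args.num_key_value_heads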