diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py
index b30456fd..e4a8cc7d 100644
--- a/llms/mlx_lm/models/phi3.py
+++ b/llms/mlx_lm/models/phi3.py
@@ -1,10 +1,10 @@
 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
 from .su_rope import SuScaledRotaryEmbedding
 
 
@@ -17,10 +17,10 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 10000
     rope_traditional: bool = False
-    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
     max_position_embeddings: int = 131072
     original_max_position_embeddings: int = 4096
 
@@ -46,6 +46,7 @@ class Attention(nn.Module):
 
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
         self.num_hidden_layers = args.num_hidden_layers
 
@@ -70,6 +71,7 @@ class Attention(nn.Module):
             )
         else:
             if args.rope_scaling and args.rope_scaling["type"] == "linear":
+                assert isinstance(args.rope_scaling["factor"], float)
                 rope_scale = 1 / args.rope_scaling["factor"]
             self.rope = nn.RoPE(
                 head_dim,
@@ -82,7 +84,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
 
@@ -141,7 +143,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
diff --git a/llms/mlx_lm/models/phi3small.py b/llms/mlx_lm/models/phi3small.py
index f3644a56..e0f2d856 100644
--- a/llms/mlx_lm/models/phi3small.py
+++ b/llms/mlx_lm/models/phi3small.py
@@ -1,3 +1,4 @@
+import math
 from dataclasses import dataclass
 from functools import partial
 from typing import Dict, Optional, Tuple, Union
@@ -5,7 +6,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
 
 
 @dataclass
@@ -19,14 +20,14 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     layer_norm_epsilon: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     mup_attn_multiplier: float = 1.0
     mup_use_scaling: bool = True
     mup_embedding_multiplier: float = 10.0
     mup_width_multiplier: float = 8.0
     rope_embedding_base: float = 1000000
     rope_position_scale: float = 1.0
-    blocksparse_block_size: int = (64,)
+    blocksparse_block_size: Tuple[int] = (64,)
     blocksparse_num_local_blocks: int = 16
     blocksparse_vert_stride: int = 8
 
@@ -58,6 +59,7 @@ class Attention(nn.Module):
 
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
         self.n_q_per_kv = n_heads // n_kv_heads
 
@@ -157,7 +159,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
 
@@ -226,7 +228,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
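
Note on the cache typing change above: a minimal sketch of how an attention layer consumes the KVCache object that now replaces the raw (keys, values) tuple. It assumes KVCache exposes an offset attribute and an update_and_fetch(keys, values) method, as in mlx_lm's base module at the time of this change; the attend() helper and its arguments are illustrative and not part of the patch.

# Sketch only (assumptions: KVCache provides .offset and .update_and_fetch;
# the attend() helper is hypothetical and not part of the patched files).
import mlx.core as mx


def attend(rope, queries, keys, values, cache=None, mask=None):
    if cache is not None:
        # Rotate queries/keys at the cache's running position, then append
        # the new keys/values and read back the full history in one call.
        queries = rope(queries, offset=cache.offset)
        keys = rope(keys, offset=cache.offset)
        keys, values = cache.update_and_fetch(keys, values)
    else:
        queries = rope(queries)
        keys = rope(keys)
    scale = queries.shape[-1] ** -0.5
    return mx.fast.scaled_dot_product_attention(
        queries, keys, values, scale=scale, mask=mask
    )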