make models/phi3.py and models/phi3small.py compatible with mypy (#833)

Yi Wang 2024-06-12 06:53:55 -07:00 committed by GitHub
parent fda41545a6
commit 6da07fb1b0
2 changed files with 15 additions and 11 deletions

models/phi3.py

@@ -1,10 +1,10 @@
 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
 from .su_rope import SuScaledRotaryEmbedding
@@ -17,10 +17,10 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 10000
     rope_traditional: bool = False
-    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
     max_position_embeddings: int = 131072
     original_max_position_embeddings: int = 4096
@@ -46,6 +46,7 @@ class Attention(nn.Module):
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
 
         self.num_hidden_layers = args.num_hidden_layers
@@ -70,6 +71,7 @@ class Attention(nn.Module):
             )
         else:
             if args.rope_scaling and args.rope_scaling["type"] == "linear":
+                assert isinstance(args.rope_scaling["factor"], float)
                 rope_scale = 1 / args.rope_scaling["factor"]
             self.rope = nn.RoPE(
                 head_dim,
@@ -82,7 +84,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
@@ -141,7 +143,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
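
The phi3.py hunks follow the usual mypy narrowing pattern: a field that defaults to None is annotated Optional[...], an assert right before the value is used narrows it back to the concrete type, and an isinstance assert does the same for a Union-typed dict value. A minimal standalone sketch of that pattern, using hypothetical names rather than the repository's code:

from dataclasses import dataclass
from typing import Dict, List, Optional, Union


@dataclass
class ExampleArgs:
    # A default of None requires Optional[int]; a bare `int = None` is a mypy error.
    num_key_value_heads: Optional[int] = None
    rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None


def kv_heads(args: ExampleArgs) -> int:
    # The assert narrows Optional[int] to int, so returning it type-checks.
    assert args.num_key_value_heads is not None
    return args.num_key_value_heads


def rope_scale(args: ExampleArgs) -> float:
    scale = 1.0
    if args.rope_scaling and args.rope_scaling["type"] == "linear":
        factor = args.rope_scaling["factor"]
        # isinstance narrows Union[float, List[float]] to float before division.
        assert isinstance(factor, float)
        scale = 1 / factor
    return scale

The asserts also act as cheap runtime checks: a config that omits num_key_value_heads fails immediately instead of propagating None into the attention math.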

models/phi3small.py

@@ -1,3 +1,4 @@
+import math
 from dataclasses import dataclass
 from functools import partial
 from typing import Dict, Optional, Tuple, Union
@@ -5,7 +6,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
 
 
 @dataclass
@@ -19,14 +20,14 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     layer_norm_epsilon: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     mup_attn_multiplier: float = 1.0
     mup_use_scaling: bool = True
    mup_embedding_multiplier: float = 10.0
     mup_width_multiplier: float = 8.0
     rope_embedding_base: float = 1000000
     rope_position_scale: float = 1.0
-    blocksparse_block_size: int = (64,)
+    blocksparse_block_size: Tuple[int] = (64,)
     blocksparse_num_local_blocks: int = 16
     blocksparse_vert_stride: int = 8
@@ -58,6 +59,7 @@ class Attention(nn.Module):
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
         self.n_q_per_kv = n_heads // n_kv_heads
@@ -157,7 +159,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
@@ -226,7 +228,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
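
The phi3small.py hunks apply the same Optional narrowing, plus one extra annotation fix: the default (64,) is a one-element tuple, so annotating the field as int is rejected by mypy, while Tuple[int] matches the default exactly. A tiny sketch of just those fields, again with a hypothetical class name:

from dataclasses import dataclass
from typing import Tuple


@dataclass
class BlockSparseArgs:
    # The default is a one-element tuple, so the annotation must be Tuple[int];
    # annotating it as int makes mypy report an incompatible assignment.
    blocksparse_block_size: Tuple[int] = (64,)
    blocksparse_num_local_blocks: int = 16
    blocksparse_vert_stride: int = 8

Tuple defaults are allowed in dataclasses because they are immutable, so no default_factory is needed here.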