make models/phi3.py and models/phi3small.py compatible with mypy (#833)

Yi Wang authored on 2024-06-12 06:53:55 -07:00, committed by GitHub
parent fda41545a6
commit 6da07fb1b0
2 changed files with 15 additions and 11 deletions

models/phi3.py

@@ -1,10 +1,10 @@
 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import mlx.core as mx
 import mlx.nn as nn

-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
 from .su_rope import SuScaledRotaryEmbedding
@@ -17,10 +17,10 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 10000
     rope_traditional: bool = False
-    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
     max_position_embeddings: int = 131072
     original_max_position_embeddings: int = 4096
@@ -46,6 +46,7 @@ class Attention(nn.Module):
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
         self.num_hidden_layers = args.num_hidden_layers
@@ -70,6 +71,7 @@ class Attention(nn.Module):
             )
         else:
             if args.rope_scaling and args.rope_scaling["type"] == "linear":
+                assert isinstance(args.rope_scaling["factor"], float)
                 rope_scale = 1 / args.rope_scaling["factor"]
             self.rope = nn.RoPE(
                 head_dim,
@@ -82,7 +84,7 @@
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
@@ -141,7 +143,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
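Both files apply the same typing pattern: a field that defaults to None must be declared Optional, and each use site then narrows the type back with an assert before using the value as a plain int or float. A minimal, self-contained sketch of the narrowing pattern the diff relies on; the Args class and the two helper functions below are illustrative stand-ins, not the model's actual code:

    from dataclasses import dataclass
    from typing import List, Optional, Union

    @dataclass
    class Args:
        num_attention_heads: int
        # "int = None" fails mypy; a None default needs Optional[int]
        num_key_value_heads: Optional[int] = None

    def n_kv_heads(args: Args) -> int:
        # Without this assert, mypy reports an incompatible return type
        # (got "Optional[int]", expected "int"); the assert narrows it.
        assert args.num_key_value_heads is not None
        return args.num_key_value_heads

    def linear_rope_scale(factor: Union[float, List[float]]) -> float:
        # isinstance() narrows a Union the same way, which is why the
        # diff asserts on rope_scaling["factor"] before dividing by it.
        assert isinstance(factor, float)
        return 1 / factor

The other change annotates the cache parameter with the KVCache class imported from .base instead of a bare (keys, values) tuple, matching the cache object the runtime actually passes in and letting mypy check how it is used.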

models/phi3small.py

@@ -1,3 +1,4 @@
+import math
 from dataclasses import dataclass
 from functools import partial
 from typing import Dict, Optional, Tuple, Union
@@ -5,7 +6,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn

-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache

 @dataclass
@@ -19,14 +20,14 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     layer_norm_epsilon: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     mup_attn_multiplier: float = 1.0
     mup_use_scaling: bool = True
     mup_embedding_multiplier: float = 10.0
     mup_width_multiplier: float = 8.0
     rope_embedding_base: float = 1000000
     rope_position_scale: float = 1.0
-    blocksparse_block_size: int = (64,)
+    blocksparse_block_size: Tuple[int] = (64,)
     blocksparse_num_local_blocks: int = 16
     blocksparse_vert_stride: int = 8
@@ -58,6 +59,7 @@ class Attention(nn.Module):
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
         self.n_q_per_kv = n_heads // n_kv_heads
@@ -157,7 +159,7 @@
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
@@ -226,7 +228,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
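The phi3small.py hunks add one more fix: blocksparse_block_size was annotated as int but initialized with the tuple (64,), so the annotation now matches its own default. A short sketch of the kind of mismatch mypy flags here (the Cfg class and field name are hypothetical):

    from dataclasses import dataclass
    from typing import Tuple

    @dataclass
    class Cfg:
        # Annotating this field as "int" is rejected by mypy, since the
        # default expression has type "Tuple[int]" but the declared
        # variable type would be "int".
        block_size: Tuple[int] = (64,)  # a 1-tuple holding one int

Note that Tuple[int] means a tuple of exactly one int, which is precisely what the (64,) default is.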