mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
make models/phi3.py and models/phi3small.py compatible with mypy (#833)
This commit is contained in:
parent
fda41545a6
commit
6da07fb1b0
@ -1,10 +1,10 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs
|
from .base import BaseModelArgs, KVCache
|
||||||
from .su_rope import SuScaledRotaryEmbedding
|
from .su_rope import SuScaledRotaryEmbedding
|
||||||
|
|
||||||
|
|
||||||
@ -17,10 +17,10 @@ class ModelArgs(BaseModelArgs):
|
|||||||
num_attention_heads: int
|
num_attention_heads: int
|
||||||
rms_norm_eps: float
|
rms_norm_eps: float
|
||||||
vocab_size: int
|
vocab_size: int
|
||||||
num_key_value_heads: int = None
|
num_key_value_heads: Optional[int] = None
|
||||||
rope_theta: float = 10000
|
rope_theta: float = 10000
|
||||||
rope_traditional: bool = False
|
rope_traditional: bool = False
|
||||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
|
||||||
max_position_embeddings: int = 131072
|
max_position_embeddings: int = 131072
|
||||||
original_max_position_embeddings: int = 4096
|
original_max_position_embeddings: int = 4096
|
||||||
|
|
||||||
@ -46,6 +46,7 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
dim = args.hidden_size
|
dim = args.hidden_size
|
||||||
self.n_heads = n_heads = args.num_attention_heads
|
self.n_heads = n_heads = args.num_attention_heads
|
||||||
|
assert args.num_key_value_heads is not None
|
||||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||||
self.num_hidden_layers = args.num_hidden_layers
|
self.num_hidden_layers = args.num_hidden_layers
|
||||||
|
|
||||||
@ -70,6 +71,7 @@ class Attention(nn.Module):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if args.rope_scaling and args.rope_scaling["type"] == "linear":
|
if args.rope_scaling and args.rope_scaling["type"] == "linear":
|
||||||
|
assert isinstance(args.rope_scaling["factor"], float)
|
||||||
rope_scale = 1 / args.rope_scaling["factor"]
|
rope_scale = 1 / args.rope_scaling["factor"]
|
||||||
self.rope = nn.RoPE(
|
self.rope = nn.RoPE(
|
||||||
head_dim,
|
head_dim,
|
||||||
@ -82,7 +84,7 @@ class Attention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
cache: Optional[KVCache] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
B, L, D = x.shape
|
B, L, D = x.shape
|
||||||
|
|
||||||
@ -141,7 +143,7 @@ class TransformerBlock(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
cache: Optional[KVCache] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||||
h = x + r
|
h = x + r
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Dict, Optional, Tuple, Union
|
||||||
@ -5,7 +6,7 @@ from typing import Dict, Optional, Tuple, Union
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs
|
from .base import BaseModelArgs, KVCache
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -19,14 +20,14 @@ class ModelArgs(BaseModelArgs):
|
|||||||
num_attention_heads: int
|
num_attention_heads: int
|
||||||
layer_norm_epsilon: float
|
layer_norm_epsilon: float
|
||||||
vocab_size: int
|
vocab_size: int
|
||||||
num_key_value_heads: int = None
|
num_key_value_heads: Optional[int] = None
|
||||||
mup_attn_multiplier: float = 1.0
|
mup_attn_multiplier: float = 1.0
|
||||||
mup_use_scaling: bool = True
|
mup_use_scaling: bool = True
|
||||||
mup_embedding_multiplier: float = 10.0
|
mup_embedding_multiplier: float = 10.0
|
||||||
mup_width_multiplier: float = 8.0
|
mup_width_multiplier: float = 8.0
|
||||||
rope_embedding_base: float = 1000000
|
rope_embedding_base: float = 1000000
|
||||||
rope_position_scale: float = 1.0
|
rope_position_scale: float = 1.0
|
||||||
blocksparse_block_size: int = (64,)
|
blocksparse_block_size: Tuple[int] = (64,)
|
||||||
blocksparse_num_local_blocks: int = 16
|
blocksparse_num_local_blocks: int = 16
|
||||||
blocksparse_vert_stride: int = 8
|
blocksparse_vert_stride: int = 8
|
||||||
|
|
||||||
@ -58,6 +59,7 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
dim = args.hidden_size
|
dim = args.hidden_size
|
||||||
self.n_heads = n_heads = args.num_attention_heads
|
self.n_heads = n_heads = args.num_attention_heads
|
||||||
|
assert args.num_key_value_heads is not None
|
||||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||||
self.n_q_per_kv = n_heads // n_kv_heads
|
self.n_q_per_kv = n_heads // n_kv_heads
|
||||||
|
|
||||||
@ -157,7 +159,7 @@ class Attention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
cache: Optional[KVCache] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
B, L, D = x.shape
|
B, L, D = x.shape
|
||||||
|
|
||||||
@ -226,7 +228,7 @@ class TransformerBlock(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
cache: Optional[KVCache] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||||
h = x + r
|
h = x + r
|
||||||
|
Loading…
Reference in New Issue
Block a user