mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
chore(mlx-lm): update phi2 model args to sync with hf config format. (#311)
* chore(mlx-lm): update phi2 model args to sync with hf config format * chore: fix type hint
This commit is contained in:
parent
7575125d5d
commit
a39b735c3b
@ -61,7 +61,7 @@ text using the given prompt.
|
|||||||
For a full list of options run:
|
For a full list of options run:
|
||||||
|
|
||||||
```
|
```
|
||||||
python -m mlx_lm generate --help
|
python -m mlx_lm.generate --help
|
||||||
```
|
```
|
||||||
|
|
||||||
To quantize a model from the command line run:
|
To quantize a model from the command line run:
|
||||||
|
@ -10,12 +10,20 @@ from .base import BaseModelArgs
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelArgs(BaseModelArgs):
|
class ModelArgs(BaseModelArgs):
|
||||||
n_positions: int = 2048
|
max_position_embeddings: int = 2048
|
||||||
vocab_size: int = 51200
|
vocab_size: int = 51200
|
||||||
n_embd: int = 2560
|
hidden_size: int = 2560
|
||||||
n_head: int = 32
|
num_attention_heads: int = 32
|
||||||
n_layer: int = 32
|
num_hidden_layers: int = 32
|
||||||
rotary_dim: int = 32
|
num_key_value_heads: int = 32
|
||||||
|
partial_rotary_factor: float = 0.4
|
||||||
|
intermediate_size: int = 10240
|
||||||
|
layer_norm_eps: float = 1e-5
|
||||||
|
rope_theta: float = 10000.0
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.num_key_value_heads is None:
|
||||||
|
self.num_key_value_heads = self.num_attention_heads
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm(nn.LayerNorm):
|
class LayerNorm(nn.LayerNorm):
|
||||||
@ -23,30 +31,66 @@ class LayerNorm(nn.LayerNorm):
|
|||||||
return super().__call__(x.astype(mx.float32)).astype(x.dtype)
|
return super().__call__(x.astype(mx.float32)).astype(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
class RoPEAttention(nn.Module):
|
class PhiAttention(nn.Module):
|
||||||
def __init__(self, dims: int, n_head: int, rotary_dim: int):
|
def __init__(self, config: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.n_head = n_head
|
self.hidden_size = config.hidden_size
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.head_dim = self.hidden_size // self.num_heads
|
||||||
|
self.num_key_value_heads = config.num_key_value_heads
|
||||||
|
self.repeats = self.num_heads // self.num_key_value_heads
|
||||||
|
self.rope_theta = config.rope_theta
|
||||||
|
self.partial_rotary_factor = config.partial_rotary_factor
|
||||||
|
|
||||||
self.q_proj = nn.Linear(dims, dims)
|
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||||
self.k_proj = nn.Linear(dims, dims)
|
raise ValueError(
|
||||||
self.v_proj = nn.Linear(dims, dims)
|
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||||
self.dense = nn.Linear(dims, dims)
|
f" and `num_heads`: {self.num_heads})."
|
||||||
|
)
|
||||||
|
|
||||||
self.rope = nn.RoPE(rotary_dim, traditional=False)
|
self.q_proj = nn.Linear(
|
||||||
|
self.hidden_size, self.num_heads * self.head_dim, bias=True
|
||||||
|
)
|
||||||
|
self.k_proj = nn.Linear(
|
||||||
|
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
|
||||||
|
)
|
||||||
|
self.v_proj = nn.Linear(
|
||||||
|
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
|
||||||
|
)
|
||||||
|
self.dense = nn.Linear(
|
||||||
|
self.num_heads * self.head_dim, self.hidden_size, bias=True
|
||||||
|
)
|
||||||
|
|
||||||
|
self.rope = nn.RoPE(
|
||||||
|
int(self.partial_rotary_factor * self.head_dim),
|
||||||
|
traditional=False,
|
||||||
|
base=self.rope_theta,
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, x, mask=None, cache=None):
|
def __call__(self, x, mask=None, cache=None):
|
||||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||||
|
|
||||||
# Extract some shapes
|
# Extract some shapes
|
||||||
n_head = self.n_head
|
|
||||||
B, L, D = queries.shape
|
B, L, D = queries.shape
|
||||||
|
|
||||||
# Prepare the queries, keys and values for the attention computation
|
# Prepare the queries, keys and values for the attention computation
|
||||||
queries = queries.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3)
|
queries = queries.reshape(B, L, self.num_heads, self.head_dim).transpose(
|
||||||
keys = keys.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3)
|
0, 2, 1, 3
|
||||||
values = values.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3)
|
)
|
||||||
|
keys = keys.reshape(B, L, self.num_key_value_heads, self.head_dim).transpose(
|
||||||
|
0, 2, 1, 3
|
||||||
|
)
|
||||||
|
values = values.reshape(
|
||||||
|
B, L, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
|
def repeat(a):
|
||||||
|
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
|
||||||
|
return a.reshape([B, self.num_heads, L, -1])
|
||||||
|
|
||||||
|
if self.repeats > 1:
|
||||||
|
keys, values = map(repeat, (keys, values))
|
||||||
|
|
||||||
# Add RoPE to the queries and keys and combine them with the cache
|
# Add RoPE to the queries and keys and combine them with the cache
|
||||||
if cache is not None:
|
if cache is not None:
|
||||||
@ -74,25 +118,23 @@ class RoPEAttention(nn.Module):
|
|||||||
return self.dense(values_hat), (keys, values)
|
return self.dense(values_hat), (keys, values)
|
||||||
|
|
||||||
|
|
||||||
class MLP(nn.Module):
|
class PhiMLP(nn.Module):
|
||||||
def __init__(self, dim, hidden_dim):
|
def __init__(self, config: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.fc1 = nn.Linear(dim, hidden_dim)
|
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||||
self.fc2 = nn.Linear(hidden_dim, dim)
|
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||||
self.act = nn.GELU(approx="precise")
|
self.act = nn.GELU(approx="precise")
|
||||||
|
|
||||||
def __call__(self, x) -> mx.array:
|
def __call__(self, x) -> mx.array:
|
||||||
return self.fc2(self.act(self.fc1(x)))
|
return self.fc2(self.act(self.fc1(x)))
|
||||||
|
|
||||||
|
|
||||||
class ParallelBlock(nn.Module):
|
class PhiDecoderLayer(nn.Module):
|
||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
dims = config.n_embd
|
self.self_attn = PhiAttention(config=config)
|
||||||
mlp_dims = dims * 4
|
self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.self_attn = RoPEAttention(dims, config.n_head, config.rotary_dim)
|
self.mlp = PhiMLP(config)
|
||||||
self.input_layernorm = LayerNorm(dims)
|
|
||||||
self.mlp = MLP(dims, mlp_dims)
|
|
||||||
|
|
||||||
def __call__(self, x, mask, cache):
|
def __call__(self, x, mask, cache):
|
||||||
h = self.input_layernorm(x)
|
h = self.input_layernorm(x)
|
||||||
@ -101,12 +143,12 @@ class ParallelBlock(nn.Module):
|
|||||||
return attn_h + ff_h + x, cache
|
return attn_h + ff_h + x, cache
|
||||||
|
|
||||||
|
|
||||||
class Transformer(nn.Module):
|
class PhiModel(nn.Module):
|
||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd)
|
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||||
self.layers = [ParallelBlock(config) for i in range(config.n_layer)]
|
self.layers = [PhiDecoderLayer(config) for i in range(config.num_hidden_layers)]
|
||||||
self.final_layernorm = LayerNorm(config.n_embd)
|
self.final_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
def __call__(self, x, mask, cache):
|
def __call__(self, x, mask, cache):
|
||||||
x = self.embed_tokens(x)
|
x = self.embed_tokens(x)
|
||||||
@ -121,8 +163,8 @@ class Transformer(nn.Module):
|
|||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = Transformer(config)
|
self.model = PhiModel(config)
|
||||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
|
Loading…
Reference in New Issue
Block a user