Mirror of https://github.com/ml-explore/mlx-examples.git
Update to StableLM code (#514)
* StableLM is now part of Transformers as stablelm rather than stablelm_epoch; the config has been updated to match
* Remove the old stablelm_epoch file
* Reference the new stablelm model type
parent 3acc1ec84e
commit 261f1280f6
@@ -18,9 +18,9 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     num_hidden_layers: int
     num_key_value_heads: int
-    rope_pct: float
+    partial_rotary_factor: float
     intermediate_size: int
-    norm_eps: float
+    layer_norm_eps: float
     rope_theta: float
     use_qkv_bias: bool

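Context note (not part of the diff): after this rename the dataclass fields line up with the keys Transformers writes to a stablelm config.json, so the arguments can be filled by filtering the raw config dict. The sketch below illustrates that mapping; the trimmed ModelArgs, the from_dict helper, and the checkpoint path are assumptions for the example, not the repository's actual loader.

from dataclasses import dataclass, fields

@dataclass
class ModelArgs:
    # Trimmed to the fields shown in the hunk above.
    num_attention_heads: int
    num_hidden_layers: int
    num_key_value_heads: int
    partial_rotary_factor: float
    intermediate_size: int
    layer_norm_eps: float
    rope_theta: float
    use_qkv_bias: bool

    @classmethod
    def from_dict(cls, params: dict) -> "ModelArgs":
        # Keep only the keys that correspond to dataclass fields; a real
        # config.json also carries vocab_size, hidden_size, etc.
        names = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in params.items() if k in names})

# Hypothetical usage with a Transformers-format stablelm checkpoint:
#   import json
#   with open("path/to/stablelm/config.json") as f:
#       args = ModelArgs.from_dict(json.load(f))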
@@ -35,7 +35,7 @@ class Attention(nn.Module):
         self.num_key_value_heads = config.num_key_value_heads
         self.repeats = self.num_heads // self.num_key_value_heads
         self.rope_theta = config.rope_theta
-        self.rope_pct = config.rope_pct
+        self.partial_rotary_factor = config.partial_rotary_factor

         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -61,7 +61,7 @@ class Attention(nn.Module):
         )

         self.rope = nn.RoPE(
-            int(self.rope_pct * self.head_dim),
+            int(self.partial_rotary_factor * self.head_dim),
             traditional=False,
             base=self.rope_theta,
         )
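Context note (not part of the diff): partial_rotary_factor (formerly rope_pct) decides how many head dimensions receive rotary embeddings; nn.RoPE is constructed with int(partial_rotary_factor * head_dim) dims and, to my understanding, passes the remaining features through unrotated. The sketch below uses made-up shapes and a 0.25 factor purely for illustration.

import mlx.core as mx
import mlx.nn as nn

head_dim = 64
partial_rotary_factor = 0.25                       # illustrative value
rotary_dims = int(partial_rotary_factor * head_dim)

rope = nn.RoPE(rotary_dims, traditional=False, base=10_000)

# (batch, num_heads, seq_len, head_dim), as the attention layer shapes queries/keys
x = mx.random.normal((1, 8, 16, head_dim))
y = rope(x)

# Only the first `rotary_dims` features per head are rotated; the rest are untouched.
print(mx.allclose(x[..., rotary_dims:], y[..., rotary_dims:]))  # expect True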
@@ -128,11 +128,11 @@ class DecoderLayer(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
         self.self_attn = Attention(config=config)
-        self.input_layernorm = LayerNorm(config.hidden_size, eps=config.norm_eps)
+        self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.mlp = MLP(config.hidden_size, config.intermediate_size)
-        self.input_layernorm = LayerNorm(config.hidden_size, eps=config.norm_eps)
+        self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.post_attention_layernorm = LayerNorm(
-            config.hidden_size, eps=config.norm_eps
+            config.hidden_size, eps=config.layer_norm_eps
         )

     def __call__(self, x, mask, cache):
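Context note (not part of the diff): only the config key changes in this hunk; layer_norm_eps is still the epsilon added to the variance inside LayerNorm, so the numerics are unchanged. A minimal sketch with mlx.nn, using an assumed hidden size and the typical 1e-5 value:

import mlx.core as mx
import mlx.nn as nn

hidden_size = 2048                 # assumed for the example
layer_norm_eps = 1e-5              # value normally read from config.json

norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
x = mx.random.normal((1, 4, hidden_size))
y = norm(x)                        # (x - mean) / sqrt(var + eps) * weight + bias
print(y.shape)                     # (1, 4, 2048)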
@@ -148,7 +148,7 @@ class StableLM(nn.Module):
         super().__init__()
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
         self.layers = [DecoderLayer(config) for i in range(config.num_hidden_layers)]
-        self.norm = LayerNorm(config.hidden_size, eps=config.norm_eps)
+        self.norm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

     def __call__(self, x, mask, cache):
         x = self.embed_tokens(x)
@@ -29,7 +29,7 @@ def linear_to_lora_layers(model: nn.Module, num_lora_layers: int):
         "llama",
         "phi",
         "mixtral",
-        "stablelm_epoch",
+        "stablelm",
         "qwen2",
         "gemma",
     ]:
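Context note (not part of the diff): this list gates which model_type values linear_to_lora_layers accepts, so checkpoints reporting the new Transformers name stablelm are now eligible for LoRA while the retired stablelm_epoch name no longer matches. The helper below is an illustrative re-statement of that check, not the actual mlx_lm implementation.

SUPPORTED_LORA_MODEL_TYPES = {
    "llama",
    "phi",
    "mixtral",
    "stablelm",   # was "stablelm_epoch" before the Transformers rename
    "qwen2",
    "gemma",
}

def check_lora_support(model_type: str) -> None:
    # Mirrors the membership test in linear_to_lora_layers (illustrative only).
    if model_type not in SUPPORTED_LORA_MODEL_TYPES:
        raise ValueError(f"LoRA is not supported for model_type '{model_type}'")

check_lora_support("stablelm")           # passes with this change
# check_lora_support("stablelm_epoch")   # would now raise ValueError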