Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-25 18:11:17 +08:00)
Update to StableLM code (#514)

* StableLM is now part of Transformers as stablelm rather than stablelm_epoch; changed the config to match the new field names
* removed the old file
* reference the new stablelm
parent 3acc1ec84e
commit 261f1280f6
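Since Transformers now ships StableLM under the stablelm model type, the config field names change to match. A minimal sketch of the updated ModelArgs, limited to fields that appear in the diff below; it is illustrative only, not the file's full definition (the real class derives from BaseModelArgs):

from dataclasses import dataclass


@dataclass
class ModelArgs:
    # Field names now follow the Hugging Face `stablelm` config.
    hidden_size: int
    num_attention_heads: int
    num_hidden_layers: int
    num_key_value_heads: int
    partial_rotary_factor: float  # previously `rope_pct`
    intermediate_size: int
    layer_norm_eps: float  # previously `norm_eps`
    rope_theta: float
    use_qkv_bias: bool
    vocab_size: int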
@@ -18,9 +18,9 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     num_hidden_layers: int
     num_key_value_heads: int
-    rope_pct: float
+    partial_rotary_factor: float
     intermediate_size: int
-    norm_eps: float
+    layer_norm_eps: float
     rope_theta: float
     use_qkv_bias: bool
 
@@ -35,7 +35,7 @@ class Attention(nn.Module):
         self.num_key_value_heads = config.num_key_value_heads
         self.repeats = self.num_heads // self.num_key_value_heads
         self.rope_theta = config.rope_theta
-        self.rope_pct = config.rope_pct
+        self.partial_rotary_factor = config.partial_rotary_factor
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -61,7 +61,7 @@ class Attention(nn.Module):
         )
 
         self.rope = nn.RoPE(
-            int(self.rope_pct * self.head_dim),
+            int(self.partial_rotary_factor * self.head_dim),
             traditional=False,
             base=self.rope_theta,
         )
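The RoPE construction itself is unchanged: the renamed partial_rotary_factor still decides how many of each head's dimensions receive rotary position embeddings. A small standalone illustration of that arithmetic with mlx, where the head_dim and factor values are made-up examples rather than values from the commit or any checkpoint:

import mlx.core as mx
import mlx.nn as nn

head_dim = 64                 # hypothetical head size
partial_rotary_factor = 0.25  # hypothetical config value

# Only the first `rope_dims` feature dimensions are rotated; the remaining
# head dimensions pass through RoPE unchanged.
rope_dims = int(partial_rotary_factor * head_dim)  # 16
rope = nn.RoPE(rope_dims, traditional=False, base=10_000)

x = mx.random.normal((1, 8, 12, head_dim))  # (batch, heads, seq_len, head_dim)
y = rope(x)                                 # same shape as x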
@@ -128,11 +128,11 @@ class DecoderLayer(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
         self.self_attn = Attention(config=config)
-        self.input_layernorm = LayerNorm(config.hidden_size, eps=config.norm_eps)
+        self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.mlp = MLP(config.hidden_size, config.intermediate_size)
-        self.input_layernorm = LayerNorm(config.hidden_size, eps=config.norm_eps)
+        self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.post_attention_layernorm = LayerNorm(
-            config.hidden_size, eps=config.norm_eps
+            config.hidden_size, eps=config.layer_norm_eps
         )
 
     def __call__(self, x, mask, cache):
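This hunk only touches the eps keyword; for orientation, a pre-norm decoder block combines these submodules roughly as follows. This is a generic sketch of the pattern, not the file's actual __call__ body:

def decoder_block(x, mask, cache, self_attn, mlp, input_layernorm, post_attention_layernorm):
    # Pre-norm residual attention followed by a pre-norm residual MLP.
    r, cache = self_attn(input_layernorm(x), mask, cache)
    h = x + r
    out = h + mlp(post_attention_layernorm(h))
    return out, cache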
@@ -148,7 +148,7 @@ class StableLM(nn.Module):
         super().__init__()
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
         self.layers = [DecoderLayer(config) for i in range(config.num_hidden_layers)]
-        self.norm = LayerNorm(config.hidden_size, eps=config.norm_eps)
+        self.norm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
     def __call__(self, x, mask, cache):
         x = self.embed_tokens(x)
@@ -29,7 +29,7 @@ def linear_to_lora_layers(model: nn.Module, num_lora_layers: int):
         "llama",
         "phi",
         "mixtral",
-        "stablelm_epoch",
+        "stablelm",
         "qwen2",
         "gemma",
     ]:
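With the rename, the LoRA conversion gate matches the new stablelm model type. The surrounding function body is not part of this hunk; below is a rough sketch of how such a gate can be written, where the model_type value handling and the error message are assumptions rather than lines from the repository:

SUPPORTED_MODEL_TYPES = [
    "llama",
    "phi",
    "mixtral",
    "stablelm",  # matches the Transformers `stablelm` model type (was "stablelm_epoch")
    "qwen2",
    "gemma",
]


def check_lora_support(model_type: str) -> None:
    # Refuse to convert linear layers to LoRA layers for unsupported model families.
    if model_type not in SUPPORTED_MODEL_TYPES:
        raise ValueError(f"LoRA is not yet supported for model_type {model_type}")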