Update to StableLM code (#514)

* StableLM is now part of Transformers as stablelm rather than stablelm_epoch; updated the config to match the new field names

* removed the old stablelm_epoch file

* reference the new stablelm model type
commit 261f1280f6 (parent 3acc1ec84e)
Author: Ashish
Date: 2024-03-01 10:53:38 -07:00, committed by GitHub
2 changed files with 9 additions and 9 deletions
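
Because only the field names changed, older converted checkpoints whose config.json still carries the stablelm_epoch names can be adapted with a small remapping step. A minimal sketch, assuming a plain dict config; remap_legacy_config and LEGACY_KEY_MAP are hypothetical names, only the key pairs themselves come from this diff:

# Hypothetical helper: remap legacy stablelm_epoch config keys to the
# Transformers-style stablelm names adopted in this commit.
LEGACY_KEY_MAP = {
    "rope_pct": "partial_rotary_factor",
    "norm_eps": "layer_norm_eps",
}

def remap_legacy_config(config: dict) -> dict:
    remapped = dict(config)
    for old_key, new_key in LEGACY_KEY_MAP.items():
        if old_key in remapped and new_key not in remapped:
            remapped[new_key] = remapped.pop(old_key)
    return remapped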

Changed file 1 of 2:

@@ -18,9 +18,9 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     num_hidden_layers: int
     num_key_value_heads: int
-    rope_pct: float
+    partial_rotary_factor: float
     intermediate_size: int
-    norm_eps: float
+    layer_norm_eps: float
     rope_theta: float
     use_qkv_bias: bool
@@ -35,7 +35,7 @@ class Attention(nn.Module):
         self.num_key_value_heads = config.num_key_value_heads
         self.repeats = self.num_heads // self.num_key_value_heads
         self.rope_theta = config.rope_theta
-        self.rope_pct = config.rope_pct
+        self.partial_rotary_factor = config.partial_rotary_factor
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -61,7 +61,7 @@ class Attention(nn.Module):
         )
         self.rope = nn.RoPE(
-            int(self.rope_pct * self.head_dim),
+            int(self.partial_rotary_factor * self.head_dim),
             traditional=False,
             base=self.rope_theta,
         )
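
Only the attribute name changes here; the rotary computation itself is unchanged. As a rough sketch with example numbers (not taken from any particular checkpoint), the renamed factor still decides how many dimensions of each head receive rotary position embeddings:

# Example values for illustration only.
hidden_size = 2048
num_attention_heads = 32
head_dim = hidden_size // num_attention_heads        # 64
partial_rotary_factor = 0.25                         # formerly `rope_pct`
rotary_dims = int(partial_rotary_factor * head_dim)  # 16 dims are rotated by RoPE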
@@ -128,11 +128,11 @@ class DecoderLayer(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
         self.self_attn = Attention(config=config)
-        self.input_layernorm = LayerNorm(config.hidden_size, eps=config.norm_eps)
+        self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.mlp = MLP(config.hidden_size, config.intermediate_size)
-        self.input_layernorm = LayerNorm(config.hidden_size, eps=config.norm_eps)
+        self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.post_attention_layernorm = LayerNorm(
-            config.hidden_size, eps=config.norm_eps
+            config.hidden_size, eps=config.layer_norm_eps
         )
 
     def __call__(self, x, mask, cache):
@@ -148,7 +148,7 @@ class StableLM(nn.Module):
         super().__init__()
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
         self.layers = [DecoderLayer(config) for i in range(config.num_hidden_layers)]
-        self.norm = LayerNorm(config.hidden_size, eps=config.norm_eps)
+        self.norm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
     def __call__(self, x, mask, cache):
         x = self.embed_tokens(x)

Changed file 2 of 2:

@@ -29,7 +29,7 @@ def linear_to_lora_layers(model: nn.Module, num_lora_layers: int):
         "llama",
         "phi",
         "mixtral",
-        "stablelm_epoch",
+        "stablelm",
         "qwen2",
         "gemma",
     ]:
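
One practical consequence of this last change: LoRA tuning now matches the model_type string "stablelm" rather than "stablelm_epoch". A minimal sketch of that gating (is_lora_supported is a hypothetical name; only the list of type strings comes from the diff above):

# Illustrative only: the supported-type check now matches "stablelm",
# not the old "stablelm_epoch".
SUPPORTED_LORA_MODEL_TYPES = {"llama", "phi", "mixtral", "stablelm", "qwen2", "gemma"}

def is_lora_supported(model_type: str) -> bool:
    return model_type in SUPPORTED_LORA_MODEL_TYPES

assert is_lora_supported("stablelm")
assert not is_lora_supported("stablelm_epoch")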