mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
support internlm2 (#797)
* support internlm2 * only attention projections --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
ca7ce60c91
commit
aac98ca6f4
@ -120,6 +120,7 @@ Here are a few examples of Hugging Face models that work with this example:
|
|||||||
- [pfnet/plamo-13b](https://huggingface.co/pfnet/plamo-13b)
|
- [pfnet/plamo-13b](https://huggingface.co/pfnet/plamo-13b)
|
||||||
- [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct)
|
- [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct)
|
||||||
- [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b)
|
- [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b)
|
||||||
|
- [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b)
|
||||||
|
|
||||||
Most
|
Most
|
||||||
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
|
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
|
||||||
|
@ -12,6 +12,7 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
|
|||||||
- Gemma
|
- Gemma
|
||||||
- OLMo
|
- OLMo
|
||||||
- MiniCPM
|
- MiniCPM
|
||||||
|
- InternLM2
|
||||||
|
|
||||||
## Contents
|
## Contents
|
||||||
|
|
||||||
|
198
llms/mlx_lm/models/internlm2.py
Normal file
198
llms/mlx_lm/models/internlm2.py
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from .base import BaseModelArgs
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArgs(BaseModelArgs):
|
||||||
|
model_type: str
|
||||||
|
hidden_size: int
|
||||||
|
num_hidden_layers: int
|
||||||
|
intermediate_size: int
|
||||||
|
num_attention_heads: int
|
||||||
|
rms_norm_eps: float
|
||||||
|
vocab_size: int
|
||||||
|
bias: bool = True
|
||||||
|
num_key_value_heads: int = None
|
||||||
|
rope_theta: float = 10000
|
||||||
|
rope_traditional: bool = False
|
||||||
|
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||||
|
tie_word_embeddings: bool = False
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.num_key_value_heads is None:
|
||||||
|
self.num_key_value_heads = self.num_attention_heads
|
||||||
|
|
||||||
|
if self.rope_scaling:
|
||||||
|
required_keys = {"factor", "type"}
|
||||||
|
if not all(key in self.rope_scaling for key in required_keys):
|
||||||
|
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||||
|
|
||||||
|
if self.rope_scaling["type"] != "linear":
|
||||||
|
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
dim = args.hidden_size
|
||||||
|
self.n_heads = n_heads = args.num_attention_heads
|
||||||
|
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||||
|
self.n_kv_groups = n_heads // args.num_key_value_heads
|
||||||
|
|
||||||
|
self.head_dim = head_dim = args.hidden_size // n_heads
|
||||||
|
self.scale = head_dim**-0.5
|
||||||
|
|
||||||
|
self.wqkv = nn.Linear(
|
||||||
|
dim, (n_heads + 2 * n_kv_heads) * head_dim, bias=args.bias
|
||||||
|
)
|
||||||
|
self.wo = nn.Linear(n_heads * head_dim, dim, bias=args.bias)
|
||||||
|
|
||||||
|
rope_scale = (
|
||||||
|
1 / args.rope_scaling["factor"]
|
||||||
|
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
|
||||||
|
else 1
|
||||||
|
)
|
||||||
|
self.rope = nn.RoPE(
|
||||||
|
head_dim,
|
||||||
|
traditional=args.rope_traditional,
|
||||||
|
base=args.rope_theta,
|
||||||
|
scale=rope_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
|
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
B, L, D = x.shape
|
||||||
|
|
||||||
|
qkv_states = self.wqkv(x)
|
||||||
|
qkv_states = qkv_states.reshape(B, L, -1, 2 + self.n_kv_groups, self.head_dim)
|
||||||
|
|
||||||
|
queries = qkv_states[..., : self.n_kv_groups, :]
|
||||||
|
queries = queries.reshape(B, L, -1, self.head_dim)
|
||||||
|
keys = qkv_states[..., -2, :]
|
||||||
|
values = qkv_states[..., -1, :]
|
||||||
|
|
||||||
|
# Prepare the queries, keys and values for the attention computation
|
||||||
|
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
|
if cache is not None:
|
||||||
|
queries = self.rope(queries, offset=cache.offset)
|
||||||
|
keys = self.rope(keys, offset=cache.offset)
|
||||||
|
keys, values = cache.update_and_fetch(keys, values)
|
||||||
|
else:
|
||||||
|
queries = self.rope(queries)
|
||||||
|
keys = self.rope(keys)
|
||||||
|
|
||||||
|
output = mx.fast.scaled_dot_product_attention(
|
||||||
|
queries, keys, values, scale=self.scale, mask=mask
|
||||||
|
)
|
||||||
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
return self.wo(output)
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(self, dim, hidden_dim):
|
||||||
|
super().__init__()
|
||||||
|
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||||
|
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||||
|
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||||
|
|
||||||
|
def __call__(self, x) -> mx.array:
|
||||||
|
return self.w2(nn.silu(self.w1(x)) * self.w3(x))
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerBlock(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.attention = Attention(args)
|
||||||
|
self.feed_forward = MLP(args.hidden_size, args.intermediate_size)
|
||||||
|
self.attention_norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||||
|
self.ffn_norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
|
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
r = self.attention(self.attention_norm(x), mask, cache)
|
||||||
|
h = x + r
|
||||||
|
r = self.feed_forward(self.ffn_norm(h))
|
||||||
|
out = h + r
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class InternLM2Model(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
assert args.vocab_size > 0
|
||||||
|
self.tok_embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||||
|
self.layers = [
|
||||||
|
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||||
|
]
|
||||||
|
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
inputs: mx.array,
|
||||||
|
cache=None,
|
||||||
|
):
|
||||||
|
h = self.tok_embeddings(inputs)
|
||||||
|
|
||||||
|
mask = None
|
||||||
|
if h.shape[1] > 1:
|
||||||
|
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||||
|
mask = mask.astype(h.dtype)
|
||||||
|
|
||||||
|
if cache is None:
|
||||||
|
cache = [None] * len(self.layers)
|
||||||
|
|
||||||
|
for layer, c in zip(self.layers, cache):
|
||||||
|
h = layer(h, mask, cache=c)
|
||||||
|
|
||||||
|
return self.norm(h)
|
||||||
|
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
self.model_type = args.model_type
|
||||||
|
self.model = InternLM2Model(args)
|
||||||
|
if not args.tie_word_embeddings:
|
||||||
|
self.output = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
inputs: mx.array,
|
||||||
|
cache=None,
|
||||||
|
):
|
||||||
|
out = self.model(inputs, cache)
|
||||||
|
if self.args.tie_word_embeddings:
|
||||||
|
out = self.model.tok_embeddings.as_linear(out)
|
||||||
|
else:
|
||||||
|
out = self.output(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
@property
|
||||||
|
def layers(self):
|
||||||
|
return self.model.layers
|
||||||
|
|
||||||
|
@property
|
||||||
|
def head_dim(self):
|
||||||
|
return self.args.hidden_size // self.args.num_attention_heads
|
||||||
|
|
||||||
|
@property
|
||||||
|
def n_kv_heads(self):
|
||||||
|
return self.args.num_key_value_heads
|
@ -119,6 +119,8 @@ def linear_to_lora_layers(
|
|||||||
keys = set(["mixer.Wqkv", "moe.gate"])
|
keys = set(["mixer.Wqkv", "moe.gate"])
|
||||||
elif model.model_type == "dbrx":
|
elif model.model_type == "dbrx":
|
||||||
keys = set(["norm_attn_norm.attn.Wqkv", "ffn.router.layer"])
|
keys = set(["norm_attn_norm.attn.Wqkv", "ffn.router.layer"])
|
||||||
|
elif model.model_type == "internlm2":
|
||||||
|
keys = set(["attention.wqkv", "attention.wo"])
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Lora does not support {model.model_type}")
|
raise ValueError(f"Lora does not support {model.model_type}")
|
||||||
|
|
||||||
|
@ -397,6 +397,23 @@ class TestModels(unittest.TestCase):
|
|||||||
len(args.ffn_multipliers),
|
len(args.ffn_multipliers),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_internlm2(self):
|
||||||
|
from mlx_lm.models import internlm2
|
||||||
|
|
||||||
|
args = internlm2.ModelArgs(
|
||||||
|
model_type="internlm2",
|
||||||
|
hidden_size=1024,
|
||||||
|
num_hidden_layers=4,
|
||||||
|
intermediate_size=2048,
|
||||||
|
num_attention_heads=4,
|
||||||
|
rms_norm_eps=1e-5,
|
||||||
|
vocab_size=10000,
|
||||||
|
)
|
||||||
|
model = internlm2.Model(args)
|
||||||
|
self.model_test_runner(
|
||||||
|
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user