mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-07-19 09:31:13 +08:00
Lazy import + refactor Lora layer addition (#426)
* lazy model import in mlx_lm * change lora loading * fix olmo lora * remove a bunch of unused stuff from plamo * move phixtral to mlx-lm and out of llms/
This commit is contained in:
parent
4576946151
commit
d4666615bb
@ -9,7 +9,8 @@ from mlx.utils import tree_flatten
|
||||
|
||||
from .tuner.lora import LoRALinear
|
||||
from .tuner.trainer import TrainingArgs, evaluate, train
|
||||
from .utils import LORA_SUPPORTED_MODELS, generate, load
|
||||
from .tuner.utils import linear_to_lora_layers
|
||||
from .utils import generate, load
|
||||
|
||||
|
||||
def build_parser():
|
||||
@ -169,19 +170,10 @@ if __name__ == "__main__":
|
||||
print("Loading pretrained model")
|
||||
model, tokenizer = load(args.model)
|
||||
|
||||
if model.__class__ not in LORA_SUPPORTED_MODELS:
|
||||
raise ValueError(
|
||||
f"Model {model.__class__} not supported. "
|
||||
f"Supported models: {LORA_SUPPORTED_MODELS}"
|
||||
)
|
||||
|
||||
# Freeze all layers other than LORA linears
|
||||
# Freeze all layers
|
||||
model.freeze()
|
||||
for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
|
||||
l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
|
||||
l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
|
||||
if hasattr(l, "block_sparse_moe"):
|
||||
l.block_sparse_moe.gate = LoRALinear.from_linear(l.block_sparse_moe.gate)
|
||||
# Convert linear layers to lora layers and unfreeze in the process
|
||||
linear_to_lora_layers(model, args.lora_layers)
|
||||
|
||||
p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
|
||||
print(f"Total parameters {p:.3f}M")
|
||||
|
@ -9,6 +9,7 @@ from .base import BaseModelArgs
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
@ -18,7 +19,6 @@ class ModelArgs(BaseModelArgs):
|
||||
num_key_value_heads: int = None
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
model_type: str = None
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
@ -190,6 +190,7 @@ class LlamaModel(nn.Module):
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.model = LlamaModel(args)
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
|
@ -21,9 +21,9 @@ class ModelArgs(BaseModelArgs):
|
||||
num_local_experts: int = 8
|
||||
rms_norm_eps: float = 1e-5
|
||||
vocab_size: int
|
||||
model_type: str
|
||||
rope_theta: float = 1e6
|
||||
rope_traditional: bool = False
|
||||
model_type: str = None
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
@ -252,6 +252,7 @@ class MixtralModel(nn.Module):
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.model = MixtralModel(args)
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
|
@ -7,18 +7,25 @@ import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
|
||||
try:
|
||||
import hf_olmo
|
||||
except ImportError:
|
||||
print("To run olmo install ai2-olmo: pip install ai2-olmo")
|
||||
exit(1)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
d_model: int
|
||||
n_layers: int
|
||||
mlp_hidden_size: int
|
||||
n_heads: int
|
||||
vocab_size: int
|
||||
embedding_size: int
|
||||
model_type: str
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
model_type: str = None
|
||||
mlp_ratio: int = 4
|
||||
weight_tying: bool = False
|
||||
|
||||
@ -162,11 +169,7 @@ class OlmoModel(nn.Module):
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
try:
|
||||
import hf_olmo
|
||||
except ImportError:
|
||||
print("To run olmo install ai2-olmo: pip install ai2-olmo")
|
||||
exit(1)
|
||||
self.model_type = args.model_type
|
||||
self.model = OlmoModel(args)
|
||||
|
||||
def __call__(
|
||||
|
@ -10,6 +10,7 @@ from .base import BaseModelArgs
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
max_position_embeddings: int = 2048
|
||||
vocab_size: int = 51200
|
||||
hidden_size: int = 2560
|
||||
@ -163,6 +164,7 @@ class PhiModel(nn.Module):
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = config.model_type
|
||||
self.model = PhiModel(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
|
@ -8,6 +8,7 @@ from typing import Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from huggingface_hub import snapshot_download
|
||||
from mlx.utils import tree_unflatten
|
||||
from transformers import AutoTokenizer
|
||||
@ -15,6 +16,7 @@ from transformers import AutoTokenizer
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
model_type: str
|
||||
max_sequence_length: int = 2048
|
||||
num_vocab: int = 51200
|
||||
model_dim: int = 2560
|
||||
@ -110,30 +112,37 @@ class MOE(nn.Module):
|
||||
self.mlp = [MLP(self.dim, self.hidden_dim) for _ in range(self.num_experts)]
|
||||
self.gate = nn.Linear(args.model_dim, self.num_experts, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
ne = self.num_experts_per_tok
|
||||
orig_shape = x.shape
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
gates = self.gate(x)
|
||||
if ne < self.num_experts:
|
||||
inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
|
||||
else:
|
||||
inds = mx.broadcast_to(mx.arange(ne), gates.shape)
|
||||
|
||||
inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne, axis=-1))[:, :ne]
|
||||
scores = mx.softmax(
|
||||
mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
|
||||
axis=-1,
|
||||
).astype(gates.dtype)
|
||||
|
||||
y = []
|
||||
for xt, st, it in zip(x, scores, inds.tolist()):
|
||||
yt = mx.concatenate([self.mlp[e](xt)[:, None] for e in it], axis=-1)
|
||||
yt = (yt * st).sum(axis=-1)
|
||||
y.append(yt[None, :])
|
||||
yc = mx.concatenate(y)
|
||||
if self.training:
|
||||
ys = []
|
||||
y = mx.zeros((x.shape[0], ne, x.shape[-1]))
|
||||
for e, expert in enumerate(self.mlp):
|
||||
idx1, idx2 = map(mx.array, np.where(inds == e))
|
||||
if idx1.size == 0:
|
||||
continue
|
||||
y[idx1, idx2] = expert(x[idx1])
|
||||
|
||||
return yc.reshape(orig_shape)
|
||||
y = (y * scores[..., None]).sum(axis=1)
|
||||
else:
|
||||
y = []
|
||||
for xt, st, it in zip(x, scores, inds.tolist()):
|
||||
yt = mx.concatenate([self.mlp[e](xt)[:, None] for e in it], axis=-1)
|
||||
yt = (yt * st).sum(axis=-1)
|
||||
y.append(yt[None, :])
|
||||
y = mx.concatenate(y)
|
||||
|
||||
return y.reshape(orig_shape)
|
||||
|
||||
|
||||
class ParallelBlock(nn.Module):
|
||||
@ -190,6 +199,7 @@ class OutputHead(nn.Module):
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = config.model_type
|
||||
self.transformer = TransformerDecoder(config)
|
||||
self.lm_head = OutputHead(config)
|
||||
|
||||
@ -206,57 +216,3 @@ class Model(nn.Module):
|
||||
|
||||
y, cache = self.transformer(x, mask, cache)
|
||||
return self.lm_head(y), cache
|
||||
|
||||
|
||||
def generate(prompt: mx.array, model: Model, temp: float = 0.0):
|
||||
def sample(logits):
|
||||
if temp == 0:
|
||||
return mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
return mx.random.categorical(logits * (1 / temp))
|
||||
|
||||
y = prompt
|
||||
cache = None
|
||||
while True:
|
||||
logits, cache = model(y[None], cache=cache)
|
||||
logits = logits[:, -1, :]
|
||||
y = sample(logits)
|
||||
yield y
|
||||
|
||||
|
||||
def load(path_or_hf_repo: str):
|
||||
# If the path exists, it will try to load model form it
|
||||
# otherwise download and cache from the hf_repo and cache
|
||||
model_path = Path(path_or_hf_repo)
|
||||
if not model_path.exists():
|
||||
model_path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_hf_repo,
|
||||
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
|
||||
)
|
||||
)
|
||||
|
||||
with open(model_path / "config.json", "r") as f:
|
||||
config = json.loads(f.read())
|
||||
quantization = config.get("quantization", None)
|
||||
model_args = ModelArgs.from_dict(config)
|
||||
|
||||
weight_files = glob.glob(str(model_path / "*.safetensors"))
|
||||
if len(weight_files) == 0:
|
||||
raise FileNotFoundError("No safetensors found in {}".format(model_path))
|
||||
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(wf).items())
|
||||
|
||||
model = Model(model_args)
|
||||
if quantization is not None:
|
||||
nn.QuantizedLinear.quantize_module(model, **quantization)
|
||||
|
||||
model.load_weights(list(weights.items()))
|
||||
|
||||
mx.eval(model.parameters())
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_path,
|
||||
)
|
||||
return model, tokenizer
|
@ -1,126 +1,25 @@
|
||||
from typing import Any, List, NamedTuple, Optional, Tuple, Union
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from .base import BaseModelArgs
|
||||
|
||||
|
||||
class DecoderInput(NamedTuple):
|
||||
hidden_states: mx.array
|
||||
position_ids: mx.array
|
||||
attention_mask: Optional[mx.array] = None
|
||||
past_key_values: Optional[List[mx.array]] = None
|
||||
output_hidden_states: Optional[bool] = False
|
||||
output_attentions: Optional[bool] = False
|
||||
use_cache: Optional[bool] = False
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
|
||||
class DecoderOutput(NamedTuple):
|
||||
hidden_states: mx.array
|
||||
all_hidden_states: Optional[Tuple[mx.array, ...]]
|
||||
all_self_attns: Optional[Tuple[mx.array, ...]]
|
||||
next_decoder_cache: Optional[Tuple[mx.array, ...]]
|
||||
|
||||
|
||||
class ModelArgs(PretrainedConfig): # type: ignore
|
||||
model_type: str = "plamo"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int = 32000,
|
||||
hidden_size: int = 4096,
|
||||
intermediate_size: int = 13312,
|
||||
num_hidden_layers: int = 32,
|
||||
num_attention_heads: int = 32,
|
||||
max_position_embeddings: int = 2048,
|
||||
initializer_range: float = 0.02,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
use_cache: bool = True,
|
||||
tokenizer_class: str = "PlamoTokenizer",
|
||||
pad_token_id: Optional[int] = None,
|
||||
bos_token_id: int = 1,
|
||||
eos_token_id: int = 2,
|
||||
n_shared_head: int = 8,
|
||||
tie_word_embeddings: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.n_shared_head = n_shared_head
|
||||
|
||||
super().__init__(
|
||||
tokenizer_class=tokenizer_class,
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class RotaryEmbedding:
|
||||
def __init__(
|
||||
self, dim: int, max_position_embeddings: int = 2048, base: int = 10000
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
self.inv_freq = 1.0 / mx.power(
|
||||
self.base, mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim
|
||||
)
|
||||
self.cos_cached = mx.zeros((1, 1, max_position_embeddings, dim))
|
||||
self.sin_cached = mx.zeros((1, 1, max_position_embeddings, dim))
|
||||
self._set_cos_sin_cache(max_position_embeddings)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len: int) -> None:
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = mx.arange(self.max_seq_len_cached) # type: ignore
|
||||
|
||||
freqs = mx.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = mx.concatenate((freqs, freqs), axis=-1)
|
||||
self.cos_cached = emb.cos()[None, None, :, :]
|
||||
self.sin_cached = emb.sin()[None, None, :, :]
|
||||
|
||||
def __call__(self, x: mx.array, seq_len: int) -> Tuple[mx.array, mx.array]:
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self._set_cos_sin_cache(seq_len)
|
||||
|
||||
return (
|
||||
self.cos_cached[:, :, :seq_len, ...].astype(x.dtype), # type: ignore
|
||||
self.sin_cached[:, :, :seq_len, ...].astype(x.dtype), # type: ignore
|
||||
)
|
||||
|
||||
|
||||
def _rotate_half(x: mx.array) -> mx.array:
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return mx.concatenate((-x2, x1), axis=-1)
|
||||
|
||||
|
||||
def _rotary_pos_emb(
|
||||
x: mx.array, cos: mx.array, sin: mx.array, position_ids: mx.array
|
||||
) -> mx.array:
|
||||
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
||||
cos = mx.squeeze(cos, (0, 1)) # [seq_len, dim]
|
||||
sin = mx.squeeze(sin, (0, 1)) # [seq_len, dim]
|
||||
cos = cos[position_ids][:, None] # [bs, 1, seq_len, dim]
|
||||
sin = sin[position_ids][:, None] # [bs, 1, seq_len, dim]
|
||||
x_embed = (x * cos) + (_rotate_half(x) * sin)
|
||||
return x_embed
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
n_shared_head: int = (8,)
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
@ -143,7 +42,6 @@ class Attention(nn.Module):
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
head_dim = self.hidden_size // config.num_attention_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
|
||||
self.q_num_heads = config.num_attention_heads
|
||||
self.qk_dim = self.v_dim = head_dim
|
||||
@ -165,15 +63,17 @@ class Attention(nn.Module):
|
||||
self.o_proj = nn.Linear(
|
||||
self.q_num_heads * self.v_dim, self.hidden_size, bias=False
|
||||
)
|
||||
self.rotary_emb = RotaryEmbedding(
|
||||
self.qk_dim, max_position_embeddings=self.max_position_embeddings
|
||||
self.rotary_emb = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=config.rope_traditional,
|
||||
base=config.rope_theta,
|
||||
scale=1.0,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
position_ids: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
|
||||
bsz, q_len, _ = hidden_states.shape
|
||||
@ -204,13 +104,11 @@ class Attention(nn.Module):
|
||||
key_states = _expand_kv(key_states)
|
||||
value_states = _expand_kv(value_states)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
kv_seq_len = 0
|
||||
if cache is not None:
|
||||
kv_seq_len += cache[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
assert position_ids is not None
|
||||
query_states = _rotary_pos_emb(query_states, cos, sin, position_ids)
|
||||
key_states = _rotary_pos_emb(key_states, cos, sin, position_ids)
|
||||
query_states = self.rotary_emb(query_states, offset=kv_seq_len)
|
||||
key_states = self.rotary_emb(key_states, offset=kv_seq_len)
|
||||
|
||||
if cache is not None:
|
||||
# reuse k, v, self_attention
|
||||
@ -235,10 +133,9 @@ class MLP(nn.Module):
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
self.act_fn = nn.silu
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) # type: ignore
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) # type: ignore
|
||||
|
||||
|
||||
class PlamoDecoderLayer(nn.Module):
|
||||
@ -254,7 +151,6 @@ class PlamoDecoderLayer(nn.Module):
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
position_ids: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> Tuple[Any, ...]:
|
||||
# from LlamaDecoder
|
||||
@ -266,18 +162,14 @@ class PlamoDecoderLayer(nn.Module):
|
||||
hidden_states_sa, cache = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
cache=cache,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states_mlp = self.mlp(hidden_states)
|
||||
|
||||
# Residual ("Parallel Layers" is used here, which is different from the normal residual connection)
|
||||
# See "GPT-NeoX-20B: An Open-Source Autoregressive Language Model" for Parallel Layers
|
||||
hidden_states = residual + hidden_states_sa + hidden_states_mlp
|
||||
|
||||
return hidden_states, cache # type: ignore
|
||||
return hidden_states, cache
|
||||
|
||||
|
||||
class PlamoDecoder(nn.Module):
|
||||
@ -289,24 +181,14 @@ class PlamoDecoder(nn.Module):
|
||||
|
||||
|
||||
class PlamoModel(nn.Module):
|
||||
config_class = ModelArgs
|
||||
_no_split_modules: List[str]
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["PlamoDecoderLayer"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
|
||||
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = PlamoDecoder(config) # type: ignore
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@ -326,10 +208,9 @@ class PlamoModel(nn.Module):
|
||||
else:
|
||||
if cache[0] is not None:
|
||||
past_key_values_length = cache[0][0].shape[2]
|
||||
position_ids = _create_position_ids(h.shape[1], past_key_values_length)
|
||||
|
||||
for e, layer in enumerate(self.layers.layers):
|
||||
h, c = layer(h, mask, position_ids, cache[e])
|
||||
h, c = layer(h, mask, cache[e])
|
||||
if cache is not None:
|
||||
cache[e] = c
|
||||
else:
|
||||
@ -338,22 +219,13 @@ class PlamoModel(nn.Module):
|
||||
return self.norm(h), cache
|
||||
|
||||
|
||||
def _create_position_ids(seq_length: int, past_key_values_length: int = 0) -> mx.array:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = mx.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=mx.int64
|
||||
)
|
||||
position_ids = position_ids[None, ...].reshape(-1, seq_length)
|
||||
|
||||
return position_ids
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: PretrainedConfig) -> None:
|
||||
def __init__(self, args: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.model = PlamoModel(config)
|
||||
self.model_type = args.model_type
|
||||
self.model = PlamoModel(args)
|
||||
self.lm_head: nn.Module = nn.Linear(
|
||||
config.hidden_size, config.vocab_size, bias=False
|
||||
args.hidden_size, args.vocab_size, bias=False
|
||||
)
|
||||
|
||||
def __call__(
|
||||
|
@ -9,6 +9,7 @@ from .base import BaseModelArgs
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int = 2048
|
||||
num_attention_heads: int = 16
|
||||
num_hidden_layers: int = 24
|
||||
@ -160,6 +161,7 @@ class QwenModel(nn.Module):
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = config.model_type
|
||||
self.transformer = QwenModel(config)
|
||||
self.lm_head = nn.Linear(
|
||||
config.hidden_size, config.vocab_size, bias=not config.no_bias
|
||||
|
@ -9,6 +9,7 @@ from .base import BaseModelArgs
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
@ -190,6 +191,7 @@ class Qwen2Model(nn.Module):
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.model = Qwen2Model(args)
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
|
@ -11,6 +11,7 @@ from .base import BaseModelArgs
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
max_position_embeddings: int
|
||||
model_type: str
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
num_attention_heads: int
|
||||
@ -169,6 +170,7 @@ class StableLM(nn.Module):
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = config.model_type
|
||||
self.model = StableLM(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
|
@ -7,6 +7,44 @@ from mlx.utils import tree_unflatten
|
||||
from .lora import LoRALinear
|
||||
|
||||
|
||||
def linear_to_lora_layers(model: nn.Module, num_lora_layers: int):
|
||||
"""
|
||||
Convert some of the models linear layers to lora layers.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The neural network model.
|
||||
num_lora_layers (int): The number of blocks to convert to lora layers
|
||||
starting from the last layer.
|
||||
"""
|
||||
if model.model_type in [
|
||||
"mistral",
|
||||
"llama",
|
||||
"phi",
|
||||
"mixtral",
|
||||
"stablelm_epoch",
|
||||
"qwen2",
|
||||
]:
|
||||
for l in model.model.layers[len(model.model.layers) - num_lora_layers :]:
|
||||
l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
|
||||
l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
|
||||
if hasattr(l, "block_sparse_moe"):
|
||||
l.block_sparse_moe.gate = LoRALinear.from_linear(
|
||||
l.block_sparse_moe.gate
|
||||
)
|
||||
elif model.model_type == "olmo":
|
||||
for l in model.model.transformer.blocks[
|
||||
len(model.model.transformer.blocks) - num_lora_layers :
|
||||
]:
|
||||
l.att_proj = LoRALinear.from_linear(l.att_proj)
|
||||
elif model.model_type == "phi-msft":
|
||||
for l in model.transformer.h[len(model.transformer.h) - num_lora_layers :]:
|
||||
l.mixer.Wqkv = LoRALinear.from_linear(l.mixer.Wqkv)
|
||||
l.moe.gate = LoRALinear.from_linear(l.moe.gate)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Lora does not support {model.model_type}")
|
||||
|
||||
|
||||
def apply_lora_layers(model: nn.Module, adapter_file: str) -> nn.Module:
|
||||
"""
|
||||
Apply LoRA layers to the model.
|
||||
|
@ -1,5 +1,6 @@
|
||||
import copy
|
||||
import glob
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
@ -12,28 +13,14 @@ from huggingface_hub import snapshot_download
|
||||
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
# Local imports
|
||||
from .models import llama, mixtral, olmo, phi2, plamo, qwen, qwen2, stablelm_epoch
|
||||
from .tuner.utils import apply_lora_layers
|
||||
|
||||
# Constants
|
||||
MODEL_MAPPING = {
|
||||
"llama": llama,
|
||||
"mistral": llama, # mistral is compatible with llama
|
||||
"mixtral": mixtral,
|
||||
"phi": phi2,
|
||||
"stablelm_epoch": stablelm_epoch,
|
||||
"qwen": qwen,
|
||||
"plamo": plamo,
|
||||
"olmo": olmo,
|
||||
"qwen2": qwen2,
|
||||
MODEL_REMAPPING = {
|
||||
"mistral": "llama", # mistral is compatible with llama
|
||||
"phi-msft": "phixtral",
|
||||
}
|
||||
LORA_SUPPORTED_MODELS = [
|
||||
llama.Model,
|
||||
mixtral.Model,
|
||||
phi2.Model,
|
||||
stablelm_epoch.Model,
|
||||
qwen2.Model,
|
||||
]
|
||||
|
||||
MAX_FILE_SIZE_GB = 5
|
||||
|
||||
linear_class_predicate = (
|
||||
@ -54,12 +41,14 @@ def _get_classes(config: dict):
|
||||
A tuple containing the Model class and the ModelArgs class.
|
||||
"""
|
||||
model_type = config["model_type"]
|
||||
if model_type not in MODEL_MAPPING:
|
||||
model_type = MODEL_REMAPPING.get(model_type, model_type)
|
||||
try:
|
||||
arch = importlib.import_module(f"mlx_lm.models.{model_type}")
|
||||
except ImportError:
|
||||
msg = f"Model type {model_type} not supported."
|
||||
logging.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
arch = MODEL_MAPPING[model_type]
|
||||
return arch.Model, arch.ModelArgs
|
||||
|
||||
|
||||
|
@ -1,28 +0,0 @@
|
||||
# Phixtral
|
||||
|
||||
Phixtral is a Mixture of Experts (MoE) architecture inspired by
|
||||
[Mixtral](../mixtral/README.md) but made by combinding fine-tuned versions of
|
||||
Phi-2.[^1][^2]
|
||||
|
||||
### Setup
|
||||
|
||||
Install the dependencies:
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### Run
|
||||
|
||||
```
|
||||
python generate.py \
|
||||
--model mlabonne/phixtral-4x2_8 \
|
||||
--prompt "write a quick sort in Python"
|
||||
```
|
||||
|
||||
Run `python generate.py --help` to see all the options.
|
||||
|
||||
[^1]: For more details on Phixtral, see the [Hugging Face repo](https://huggingface.co/mlabonne/phixtral-4x2_8).
|
||||
[^2]: For more details on Phi-2 see Microsoft's [blog post](
|
||||
https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/)
|
||||
and the [Hugging Face repo](https://huggingface.co/microsoft/phi-2).
|
@ -1,91 +0,0 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import phixtral
|
||||
import transformers
|
||||
|
||||
|
||||
def generate(
|
||||
model: phixtral.Model,
|
||||
tokenizer: transformers.AutoTokenizer,
|
||||
prompt: str,
|
||||
max_tokens: int,
|
||||
temp: float = 0.0,
|
||||
):
|
||||
print("[INFO] Generating with Phixtral...", flush=True)
|
||||
print(prompt, end="", flush=True)
|
||||
prompt = tokenizer(
|
||||
prompt,
|
||||
return_tensors="np",
|
||||
return_attention_mask=False,
|
||||
)[
|
||||
"input_ids"
|
||||
][0]
|
||||
prompt = mx.array(prompt)
|
||||
|
||||
tic = time.time()
|
||||
tokens = []
|
||||
skip = 0
|
||||
for token, n in zip(
|
||||
phixtral.generate(prompt, model, temp),
|
||||
range(max_tokens),
|
||||
):
|
||||
if token == tokenizer.eos_token_id:
|
||||
break
|
||||
|
||||
if n == 0:
|
||||
prompt_time = time.time() - tic
|
||||
tic = time.time()
|
||||
|
||||
tokens.append(token.item())
|
||||
# if (n + 1) % 10 == 0:
|
||||
s = tokenizer.decode(tokens)
|
||||
print(s[skip:], end="", flush=True)
|
||||
skip = len(s)
|
||||
print(tokenizer.decode(tokens)[skip:], flush=True)
|
||||
gen_time = time.time() - tic
|
||||
print("=" * 10)
|
||||
if len(tokens) == 0:
|
||||
print("No tokens generated for this prompt")
|
||||
return
|
||||
prompt_tps = prompt.size / prompt_time
|
||||
gen_tps = (len(tokens) - 1) / gen_time
|
||||
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
||||
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="inference script")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="mlx_model",
|
||||
help="The path to the local model directory or Hugging Face repo.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
help="The message to be processed by the model",
|
||||
default="Write a detailed analogy between mathematics and a lighthouse.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
"-m",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Maximum number of tokens to generate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temp",
|
||||
help="The sampling temperature.",
|
||||
type=float,
|
||||
default=0.0,
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||
|
||||
args = parser.parse_args()
|
||||
mx.random.seed(args.seed)
|
||||
model, tokenizer = phixtral.load(args.model)
|
||||
generate(model, tokenizer, args.prompt, args.max_tokens, args.temp)
|
@ -1,7 +0,0 @@
|
||||
einops
|
||||
hf_transfer
|
||||
huggingface_hub
|
||||
mlx
|
||||
numpy
|
||||
torch
|
||||
transformers>=4.35
|
Loading…
Reference in New Issue
Block a user