From d4666615bb33622d5c4cf8815d61a1f5d64671e7 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 12 Feb 2024 10:51:02 -0800 Subject: [PATCH] Lazy import + refactor Lora layer addition (#426) * lazy model import in mlx_lm * change lora loading * fix olmo lora * remove a bunch of unused stuff from plamo * move phixtral to mlx-lm and out of llms/ --- llms/mlx_lm/lora.py | 18 +- llms/mlx_lm/models/llama.py | 3 +- llms/mlx_lm/models/mixtral.py | 3 +- llms/mlx_lm/models/olmo.py | 15 +- llms/mlx_lm/models/{phi2.py => phi.py} | 2 + llms/{phixtral => mlx_lm/models}/phixtral.py | 90 +++------ llms/mlx_lm/models/plamo.py | 190 +++---------------- llms/mlx_lm/models/qwen.py | 2 + llms/mlx_lm/models/qwen2.py | 2 + llms/mlx_lm/models/stablelm_epoch.py | 2 + llms/mlx_lm/tuner/utils.py | 38 ++++ llms/mlx_lm/utils.py | 29 +-- llms/phixtral/README.md | 28 --- llms/phixtral/generate.py | 91 --------- llms/phixtral/requirements.txt | 7 - 15 files changed, 127 insertions(+), 393 deletions(-) rename llms/mlx_lm/models/{phi2.py => phi.py} (98%) rename llms/{phixtral => mlx_lm/models}/phixtral.py (73%) delete mode 100644 llms/phixtral/README.md delete mode 100644 llms/phixtral/generate.py delete mode 100644 llms/phixtral/requirements.txt diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index 0c723c86..75093080 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -9,7 +9,8 @@ from mlx.utils import tree_flatten from .tuner.lora import LoRALinear from .tuner.trainer import TrainingArgs, evaluate, train -from .utils import LORA_SUPPORTED_MODELS, generate, load +from .tuner.utils import linear_to_lora_layers +from .utils import generate, load def build_parser(): @@ -169,19 +170,10 @@ if __name__ == "__main__": print("Loading pretrained model") model, tokenizer = load(args.model) - if model.__class__ not in LORA_SUPPORTED_MODELS: - raise ValueError( - f"Model {model.__class__} not supported. 
" - f"Supported models: {LORA_SUPPORTED_MODELS}" - ) - - # Freeze all layers other than LORA linears + # Freeze all layers model.freeze() - for l in model.model.layers[len(model.model.layers) - args.lora_layers :]: - l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj) - l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj) - if hasattr(l, "block_sparse_moe"): - l.block_sparse_moe.gate = LoRALinear.from_linear(l.block_sparse_moe.gate) + # Convert linear layers to lora layers and unfreeze in the process + linear_to_lora_layers(model, args.lora_layers) p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6 print(f"Total parameters {p:.3f}M") diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index b61aecaf..f44a94e7 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -9,6 +9,7 @@ from .base import BaseModelArgs @dataclass class ModelArgs(BaseModelArgs): + model_type: str hidden_size: int num_hidden_layers: int intermediate_size: int @@ -18,7 +19,6 @@ class ModelArgs(BaseModelArgs): num_key_value_heads: int = None rope_theta: float = 10000 rope_traditional: bool = False - model_type: str = None rope_scaling: Optional[Dict[str, Union[float, str]]] = None def __post_init__(self): @@ -190,6 +190,7 @@ class LlamaModel(nn.Module): class Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() + self.model_type = args.model_type self.model = LlamaModel(args) self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py index 63401de1..fbd4c7a3 100644 --- a/llms/mlx_lm/models/mixtral.py +++ b/llms/mlx_lm/models/mixtral.py @@ -21,9 +21,9 @@ class ModelArgs(BaseModelArgs): num_local_experts: int = 8 rms_norm_eps: float = 1e-5 vocab_size: int + model_type: str rope_theta: float = 1e6 rope_traditional: bool = False - model_type: str = None rope_scaling: Optional[Dict[str, Union[float, str]]] = None def __post_init__(self): @@ -252,6 +252,7 @@ class MixtralModel(nn.Module): class Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() + self.model_type = args.model_type self.model = MixtralModel(args) self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py index 2525b181..0a2c9c0d 100644 --- a/llms/mlx_lm/models/olmo.py +++ b/llms/mlx_lm/models/olmo.py @@ -7,18 +7,25 @@ import mlx.nn as nn from .base import BaseModelArgs +try: + import hf_olmo +except ImportError: + print("To run olmo install ai2-olmo: pip install ai2-olmo") + exit(1) + @dataclass class ModelArgs(BaseModelArgs): + model_type: str d_model: int n_layers: int mlp_hidden_size: int n_heads: int vocab_size: int embedding_size: int + model_type: str rope_theta: float = 10000 rope_traditional: bool = False - model_type: str = None mlp_ratio: int = 4 weight_tying: bool = False @@ -162,11 +169,7 @@ class OlmoModel(nn.Module): class Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() - try: - import hf_olmo - except ImportError: - print("To run olmo install ai2-olmo: pip install ai2-olmo") - exit(1) + self.model_type = args.model_type self.model = OlmoModel(args) def __call__( diff --git a/llms/mlx_lm/models/phi2.py b/llms/mlx_lm/models/phi.py similarity index 98% rename from llms/mlx_lm/models/phi2.py rename to llms/mlx_lm/models/phi.py index 13326080..93bba876 100644 --- a/llms/mlx_lm/models/phi2.py +++ b/llms/mlx_lm/models/phi.py @@ -10,6 
+10,7 @@ from .base import BaseModelArgs @dataclass class ModelArgs(BaseModelArgs): + model_type: str max_position_embeddings: int = 2048 vocab_size: int = 51200 hidden_size: int = 2560 @@ -163,6 +164,7 @@ class PhiModel(nn.Module): class Model(nn.Module): def __init__(self, config: ModelArgs): super().__init__() + self.model_type = config.model_type self.model = PhiModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True) diff --git a/llms/phixtral/phixtral.py b/llms/mlx_lm/models/phixtral.py similarity index 73% rename from llms/phixtral/phixtral.py rename to llms/mlx_lm/models/phixtral.py index 0aa37ebf..14ef5d45 100644 --- a/llms/phixtral/phixtral.py +++ b/llms/mlx_lm/models/phixtral.py @@ -8,6 +8,7 @@ from typing import Optional, Tuple import mlx.core as mx import mlx.nn as nn +import numpy as np from huggingface_hub import snapshot_download from mlx.utils import tree_unflatten from transformers import AutoTokenizer @@ -15,6 +16,7 @@ from transformers import AutoTokenizer @dataclass class ModelArgs: + model_type: str max_sequence_length: int = 2048 num_vocab: int = 51200 model_dim: int = 2560 @@ -110,30 +112,37 @@ class MOE(nn.Module): self.mlp = [MLP(self.dim, self.hidden_dim) for _ in range(self.num_experts)] self.gate = nn.Linear(args.model_dim, self.num_experts, bias=False) - def __call__(self, x) -> mx.array: + def __call__(self, x: mx.array) -> mx.array: ne = self.num_experts_per_tok orig_shape = x.shape x = x.reshape(-1, x.shape[-1]) gates = self.gate(x) - if ne < self.num_experts: - inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne] - else: - inds = mx.broadcast_to(mx.arange(ne), gates.shape) - + inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne, axis=-1))[:, :ne] scores = mx.softmax( mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32), axis=-1, ).astype(gates.dtype) - y = [] - for xt, st, it in zip(x, scores, inds.tolist()): - yt = mx.concatenate([self.mlp[e](xt)[:, None] for e in it], axis=-1) - yt = (yt * st).sum(axis=-1) - y.append(yt[None, :]) - yc = mx.concatenate(y) + if self.training: + ys = [] + y = mx.zeros((x.shape[0], ne, x.shape[-1])) + for e, expert in enumerate(self.mlp): + idx1, idx2 = map(mx.array, np.where(inds == e)) + if idx1.size == 0: + continue + y[idx1, idx2] = expert(x[idx1]) - return yc.reshape(orig_shape) + y = (y * scores[..., None]).sum(axis=1) + else: + y = [] + for xt, st, it in zip(x, scores, inds.tolist()): + yt = mx.concatenate([self.mlp[e](xt)[:, None] for e in it], axis=-1) + yt = (yt * st).sum(axis=-1) + y.append(yt[None, :]) + y = mx.concatenate(y) + + return y.reshape(orig_shape) class ParallelBlock(nn.Module): @@ -190,6 +199,7 @@ class OutputHead(nn.Module): class Model(nn.Module): def __init__(self, config: ModelArgs): super().__init__() + self.model_type = config.model_type self.transformer = TransformerDecoder(config) self.lm_head = OutputHead(config) @@ -206,57 +216,3 @@ class Model(nn.Module): y, cache = self.transformer(x, mask, cache) return self.lm_head(y), cache - - -def generate(prompt: mx.array, model: Model, temp: float = 0.0): - def sample(logits): - if temp == 0: - return mx.argmax(logits, axis=-1) - else: - return mx.random.categorical(logits * (1 / temp)) - - y = prompt - cache = None - while True: - logits, cache = model(y[None], cache=cache) - logits = logits[:, -1, :] - y = sample(logits) - yield y - - -def load(path_or_hf_repo: str): - # If the path exists, it will try to load model form it - # otherwise download and cache from the hf_repo and cache - model_path = 
Path(path_or_hf_repo) - if not model_path.exists(): - model_path = Path( - snapshot_download( - repo_id=path_or_hf_repo, - allow_patterns=["*.json", "*.safetensors", "tokenizer.model"], - ) - ) - - with open(model_path / "config.json", "r") as f: - config = json.loads(f.read()) - quantization = config.get("quantization", None) - model_args = ModelArgs.from_dict(config) - - weight_files = glob.glob(str(model_path / "*.safetensors")) - if len(weight_files) == 0: - raise FileNotFoundError("No safetensors found in {}".format(model_path)) - - weights = {} - for wf in weight_files: - weights.update(mx.load(wf).items()) - - model = Model(model_args) - if quantization is not None: - nn.QuantizedLinear.quantize_module(model, **quantization) - - model.load_weights(list(weights.items())) - - mx.eval(model.parameters()) - tokenizer = AutoTokenizer.from_pretrained( - model_path, - ) - return model, tokenizer diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py index 7f9aa070..b6ca2491 100644 --- a/llms/mlx_lm/models/plamo.py +++ b/llms/mlx_lm/models/plamo.py @@ -1,126 +1,25 @@ -from typing import Any, List, NamedTuple, Optional, Tuple, Union +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn import numpy as np -from transformers import PretrainedConfig + +from .base import BaseModelArgs -class DecoderInput(NamedTuple): - hidden_states: mx.array - position_ids: mx.array - attention_mask: Optional[mx.array] = None - past_key_values: Optional[List[mx.array]] = None - output_hidden_states: Optional[bool] = False - output_attentions: Optional[bool] = False - use_cache: Optional[bool] = False - gradient_checkpointing: bool = False - - -class DecoderOutput(NamedTuple): - hidden_states: mx.array - all_hidden_states: Optional[Tuple[mx.array, ...]] - all_self_attns: Optional[Tuple[mx.array, ...]] - next_decoder_cache: Optional[Tuple[mx.array, ...]] - - -class ModelArgs(PretrainedConfig): # type: ignore - model_type: str = "plamo" - - def __init__( - self, - vocab_size: int = 32000, - hidden_size: int = 4096, - intermediate_size: int = 13312, - num_hidden_layers: int = 32, - num_attention_heads: int = 32, - max_position_embeddings: int = 2048, - initializer_range: float = 0.02, - rms_norm_eps: float = 1e-6, - use_cache: bool = True, - tokenizer_class: str = "PlamoTokenizer", - pad_token_id: Optional[int] = None, - bos_token_id: int = 1, - eos_token_id: int = 2, - n_shared_head: int = 8, - tie_word_embeddings: bool = False, - **kwargs: Any, - ) -> None: - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.n_shared_head = n_shared_head - - super().__init__( - tokenizer_class=tokenizer_class, - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - -class RotaryEmbedding: - def __init__( - self, dim: int, max_position_embeddings: int = 2048, base: int = 10000 - ) -> None: - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - self.inv_freq = 1.0 / mx.power( - self.base, mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim - ) - 
self.cos_cached = mx.zeros((1, 1, max_position_embeddings, dim)) - self.sin_cached = mx.zeros((1, 1, max_position_embeddings, dim)) - self._set_cos_sin_cache(max_position_embeddings) - - def _set_cos_sin_cache(self, seq_len: int) -> None: - self.max_seq_len_cached = seq_len - t = mx.arange(self.max_seq_len_cached) # type: ignore - - freqs = mx.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = mx.concatenate((freqs, freqs), axis=-1) - self.cos_cached = emb.cos()[None, None, :, :] - self.sin_cached = emb.sin()[None, None, :, :] - - def __call__(self, x: mx.array, seq_len: int) -> Tuple[mx.array, mx.array]: - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len) - - return ( - self.cos_cached[:, :, :seq_len, ...].astype(x.dtype), # type: ignore - self.sin_cached[:, :, :seq_len, ...].astype(x.dtype), # type: ignore - ) - - -def _rotate_half(x: mx.array) -> mx.array: - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return mx.concatenate((-x2, x1), axis=-1) - - -def _rotary_pos_emb( - x: mx.array, cos: mx.array, sin: mx.array, position_ids: mx.array -) -> mx.array: - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - cos = mx.squeeze(cos, (0, 1)) # [seq_len, dim] - sin = mx.squeeze(sin, (0, 1)) # [seq_len, dim] - cos = cos[position_ids][:, None] # [bs, 1, seq_len, dim] - sin = sin[position_ids][:, None] # [bs, 1, seq_len, dim] - x_embed = (x * cos) + (_rotate_half(x) * sin) - return x_embed +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + rms_norm_eps: float + vocab_size: int + n_shared_head: int = (8,) + rope_theta: float = 10000 + rope_traditional: bool = False class RMSNorm(nn.Module): @@ -143,7 +42,6 @@ class Attention(nn.Module): self.config = config self.hidden_size = config.hidden_size head_dim = self.hidden_size // config.num_attention_heads - self.max_position_embeddings = config.max_position_embeddings self.q_num_heads = config.num_attention_heads self.qk_dim = self.v_dim = head_dim @@ -165,15 +63,17 @@ class Attention(nn.Module): self.o_proj = nn.Linear( self.q_num_heads * self.v_dim, self.hidden_size, bias=False ) - self.rotary_emb = RotaryEmbedding( - self.qk_dim, max_position_embeddings=self.max_position_embeddings + self.rotary_emb = nn.RoPE( + head_dim, + traditional=config.rope_traditional, + base=config.rope_theta, + scale=1.0, ) def __call__( self, hidden_states: mx.array, attention_mask: Optional[mx.array] = None, - position_ids: Optional[mx.array] = None, cache: Optional[Tuple[mx.array, mx.array]] = None, ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]: bsz, q_len, _ = hidden_states.shape @@ -204,13 +104,11 @@ class Attention(nn.Module): key_states = _expand_kv(key_states) value_states = _expand_kv(value_states) - kv_seq_len = key_states.shape[-2] + kv_seq_len = 0 if cache is not None: kv_seq_len += cache[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - assert position_ids is not None - query_states = _rotary_pos_emb(query_states, cos, sin, position_ids) - key_states = _rotary_pos_emb(key_states, cos, sin, position_ids) + query_states = self.rotary_emb(query_states, offset=kv_seq_len) + key_states = self.rotary_emb(key_states, offset=kv_seq_len) if cache is not None: # reuse k, v, 
self_attention @@ -235,10 +133,9 @@ class MLP(nn.Module): self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = nn.silu def __call__(self, x: mx.array) -> mx.array: - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) # type: ignore + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) # type: ignore class PlamoDecoderLayer(nn.Module): @@ -254,7 +151,6 @@ class PlamoDecoderLayer(nn.Module): self, hidden_states: mx.array, attention_mask: Optional[mx.array] = None, - position_ids: Optional[mx.array] = None, cache: Optional[Tuple[mx.array, mx.array]] = None, ) -> Tuple[Any, ...]: # from LlamaDecoder @@ -266,18 +162,14 @@ class PlamoDecoderLayer(nn.Module): hidden_states_sa, cache = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, - position_ids=position_ids, cache=cache, ) # Fully Connected hidden_states_mlp = self.mlp(hidden_states) - # Residual ("Parallel Layers" is used here, which is different from the normal residual connection) - # See "GPT-NeoX-20B: An Open-Source Autoregressive Language Model" for Parallel Layers hidden_states = residual + hidden_states_sa + hidden_states_mlp - - return hidden_states, cache # type: ignore + return hidden_states, cache class PlamoDecoder(nn.Module): @@ -289,24 +181,14 @@ class PlamoDecoder(nn.Module): class PlamoModel(nn.Module): - config_class = ModelArgs - _no_split_modules: List[str] - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["PlamoDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] - def __init__(self, config: ModelArgs): super().__init__() self.config = config - self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = PlamoDecoder(config) # type: ignore self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.gradient_checkpointing = False def __call__( self, @@ -326,10 +208,9 @@ class PlamoModel(nn.Module): else: if cache[0] is not None: past_key_values_length = cache[0][0].shape[2] - position_ids = _create_position_ids(h.shape[1], past_key_values_length) for e, layer in enumerate(self.layers.layers): - h, c = layer(h, mask, position_ids, cache[e]) + h, c = layer(h, mask, cache[e]) if cache is not None: cache[e] = c else: @@ -338,22 +219,13 @@ class PlamoModel(nn.Module): return self.norm(h), cache -def _create_position_ids(seq_length: int, past_key_values_length: int = 0) -> mx.array: - # create position_ids on the fly for batch generation - position_ids = mx.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=mx.int64 - ) - position_ids = position_ids[None, ...].reshape(-1, seq_length) - - return position_ids - - class Model(nn.Module): - def __init__(self, config: PretrainedConfig) -> None: + def __init__(self, args: ModelArgs) -> None: super().__init__() - self.model = PlamoModel(config) + self.model_type = args.model_type + self.model = PlamoModel(args) self.lm_head: nn.Module = nn.Linear( - config.hidden_size, config.vocab_size, bias=False + args.hidden_size, args.vocab_size, bias=False ) def __call__( diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py index a086a95c..aeda9c32 100644 --- 
a/llms/mlx_lm/models/qwen.py +++ b/llms/mlx_lm/models/qwen.py @@ -9,6 +9,7 @@ from .base import BaseModelArgs @dataclass class ModelArgs(BaseModelArgs): + model_type: str hidden_size: int = 2048 num_attention_heads: int = 16 num_hidden_layers: int = 24 @@ -160,6 +161,7 @@ class QwenModel(nn.Module): class Model(nn.Module): def __init__(self, config: ModelArgs): super().__init__() + self.model_type = config.model_type self.transformer = QwenModel(config) self.lm_head = nn.Linear( config.hidden_size, config.vocab_size, bias=not config.no_bias diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py index 375a5b81..41ce6d4b 100644 --- a/llms/mlx_lm/models/qwen2.py +++ b/llms/mlx_lm/models/qwen2.py @@ -9,6 +9,7 @@ from .base import BaseModelArgs @dataclass class ModelArgs(BaseModelArgs): + model_type: str hidden_size: int num_hidden_layers: int intermediate_size: int @@ -190,6 +191,7 @@ class Qwen2Model(nn.Module): class Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() + self.model_type = args.model_type self.model = Qwen2Model(args) self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) diff --git a/llms/mlx_lm/models/stablelm_epoch.py b/llms/mlx_lm/models/stablelm_epoch.py index 04d64efc..a0fe0d30 100644 --- a/llms/mlx_lm/models/stablelm_epoch.py +++ b/llms/mlx_lm/models/stablelm_epoch.py @@ -11,6 +11,7 @@ from .base import BaseModelArgs @dataclass class ModelArgs(BaseModelArgs): max_position_embeddings: int + model_type: str vocab_size: int hidden_size: int num_attention_heads: int @@ -169,6 +170,7 @@ class StableLM(nn.Module): class Model(nn.Module): def __init__(self, config: ModelArgs): super().__init__() + self.model_type = config.model_type self.model = StableLM(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index 74f677fb..b7fada8a 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -7,6 +7,44 @@ from mlx.utils import tree_unflatten from .lora import LoRALinear +def linear_to_lora_layers(model: nn.Module, num_lora_layers: int): + """ + Convert some of the models linear layers to lora layers. + + Args: + model (nn.Module): The neural network model. + num_lora_layers (int): The number of blocks to convert to lora layers + starting from the last layer. + """ + if model.model_type in [ + "mistral", + "llama", + "phi", + "mixtral", + "stablelm_epoch", + "qwen2", + ]: + for l in model.model.layers[len(model.model.layers) - num_lora_layers :]: + l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj) + l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj) + if hasattr(l, "block_sparse_moe"): + l.block_sparse_moe.gate = LoRALinear.from_linear( + l.block_sparse_moe.gate + ) + elif model.model_type == "olmo": + for l in model.model.transformer.blocks[ + len(model.model.transformer.blocks) - num_lora_layers : + ]: + l.att_proj = LoRALinear.from_linear(l.att_proj) + elif model.model_type == "phi-msft": + for l in model.transformer.h[len(model.transformer.h) - num_lora_layers :]: + l.mixer.Wqkv = LoRALinear.from_linear(l.mixer.Wqkv) + l.moe.gate = LoRALinear.from_linear(l.moe.gate) + + else: + raise ValueError(f"Lora does not support {model.model_type}") + + def apply_lora_layers(model: nn.Module, adapter_file: str) -> nn.Module: """ Apply LoRA layers to the model. 
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 32c9b7b4..f4112b6a 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -1,5 +1,6 @@ import copy import glob +import importlib import json import logging import time @@ -12,28 +13,14 @@ from huggingface_hub import snapshot_download from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer # Local imports -from .models import llama, mixtral, olmo, phi2, plamo, qwen, qwen2, stablelm_epoch from .tuner.utils import apply_lora_layers # Constants -MODEL_MAPPING = { - "llama": llama, - "mistral": llama, # mistral is compatible with llama - "mixtral": mixtral, - "phi": phi2, - "stablelm_epoch": stablelm_epoch, - "qwen": qwen, - "plamo": plamo, - "olmo": olmo, - "qwen2": qwen2, +MODEL_REMAPPING = { + "mistral": "llama", # mistral is compatible with llama + "phi-msft": "phixtral", } -LORA_SUPPORTED_MODELS = [ - llama.Model, - mixtral.Model, - phi2.Model, - stablelm_epoch.Model, - qwen2.Model, -] + MAX_FILE_SIZE_GB = 5 linear_class_predicate = ( @@ -54,12 +41,14 @@ def _get_classes(config: dict): A tuple containing the Model class and the ModelArgs class. """ model_type = config["model_type"] - if model_type not in MODEL_MAPPING: + model_type = MODEL_REMAPPING.get(model_type, model_type) + try: + arch = importlib.import_module(f"mlx_lm.models.{model_type}") + except ImportError: msg = f"Model type {model_type} not supported." logging.error(msg) raise ValueError(msg) - arch = MODEL_MAPPING[model_type] return arch.Model, arch.ModelArgs diff --git a/llms/phixtral/README.md b/llms/phixtral/README.md deleted file mode 100644 index 93524c0a..00000000 --- a/llms/phixtral/README.md +++ /dev/null @@ -1,28 +0,0 @@ -# Phixtral - -Phixtral is a Mixture of Experts (MoE) architecture inspired by -[Mixtral](../mixtral/README.md) but made by combinding fine-tuned versions of -Phi-2.[^1][^2] - -### Setup - -Install the dependencies: - -``` -pip install -r requirements.txt -``` - -### Run - -``` -python generate.py \ - --model mlabonne/phixtral-4x2_8 \ - --prompt "write a quick sort in Python" -``` - -Run `python generate.py --help` to see all the options. - -[^1]: For more details on Phixtral, see the [Hugging Face repo](https://huggingface.co/mlabonne/phixtral-4x2_8). -[^2]: For more details on Phi-2 see Microsoft's [blog post]( -https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/) -and the [Hugging Face repo](https://huggingface.co/microsoft/phi-2). diff --git a/llms/phixtral/generate.py b/llms/phixtral/generate.py deleted file mode 100644 index e1767e34..00000000 --- a/llms/phixtral/generate.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright © 2023 Apple Inc. 
- -import argparse -import time - -import mlx.core as mx -import phixtral -import transformers - - -def generate( - model: phixtral.Model, - tokenizer: transformers.AutoTokenizer, - prompt: str, - max_tokens: int, - temp: float = 0.0, -): - print("[INFO] Generating with Phixtral...", flush=True) - print(prompt, end="", flush=True) - prompt = tokenizer( - prompt, - return_tensors="np", - return_attention_mask=False, - )[ - "input_ids" - ][0] - prompt = mx.array(prompt) - - tic = time.time() - tokens = [] - skip = 0 - for token, n in zip( - phixtral.generate(prompt, model, temp), - range(max_tokens), - ): - if token == tokenizer.eos_token_id: - break - - if n == 0: - prompt_time = time.time() - tic - tic = time.time() - - tokens.append(token.item()) - # if (n + 1) % 10 == 0: - s = tokenizer.decode(tokens) - print(s[skip:], end="", flush=True) - skip = len(s) - print(tokenizer.decode(tokens)[skip:], flush=True) - gen_time = time.time() - tic - print("=" * 10) - if len(tokens) == 0: - print("No tokens generated for this prompt") - return - prompt_tps = prompt.size / prompt_time - gen_tps = (len(tokens) - 1) / gen_time - print(f"Prompt: {prompt_tps:.3f} tokens-per-sec") - print(f"Generation: {gen_tps:.3f} tokens-per-sec") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="inference script") - parser.add_argument( - "--model", - type=str, - default="mlx_model", - help="The path to the local model directory or Hugging Face repo.", - ) - parser.add_argument( - "--prompt", - help="The message to be processed by the model", - default="Write a detailed analogy between mathematics and a lighthouse.", - ) - parser.add_argument( - "--max-tokens", - "-m", - type=int, - default=100, - help="Maximum number of tokens to generate", - ) - parser.add_argument( - "--temp", - help="The sampling temperature.", - type=float, - default=0.0, - ) - parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") - - args = parser.parse_args() - mx.random.seed(args.seed) - model, tokenizer = phixtral.load(args.model) - generate(model, tokenizer, args.prompt, args.max_tokens, args.temp) diff --git a/llms/phixtral/requirements.txt b/llms/phixtral/requirements.txt deleted file mode 100644 index 016af3ae..00000000 --- a/llms/phixtral/requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ -einops -hf_transfer -huggingface_hub -mlx -numpy -torch -transformers>=4.35
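
Taken together, the patch swaps the hard-coded `MODEL_MAPPING` / `LORA_SUPPORTED_MODELS` tables for a lazy `importlib.import_module("mlx_lm.models.<model_type>")` lookup (with a small `MODEL_REMAPPING` for aliases such as mistral -> llama and phi-msft -> phixtral) and moves the per-architecture LoRA wiring into `tuner.utils.linear_to_lora_layers`. The sketch below mirrors the caller-side flow of the updated `lora.py`; the model path and layer count are illustrative placeholders, not values taken from the patch.

```python
from mlx.utils import tree_flatten

from mlx_lm.tuner.utils import linear_to_lora_layers
from mlx_lm.utils import load

# Placeholder path/repo id: any checkpoint whose model_type is handled by
# linear_to_lora_layers (llama, mistral, mixtral, phi, stablelm_epoch, qwen2,
# olmo, or phi-msft/phixtral). load() resolves the model module lazily.
model, tokenizer = load("mlx_model")

# Freeze everything, then convert the attention projections (and MoE gates,
# where present) of the last N blocks to LoRA layers; the conversion
# unfreezes the newly added LoRA parameters.
model.freeze()
linear_to_lora_layers(model, 16)  # 16 is an illustrative layer count

p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
print(f"Total parameters {p:.3f}M")
```

Because a model module is now imported only when its model_type is requested, optional dependencies such as ai2-olmo (imported at the top of the refactored olmo.py) are only needed when an OLMo checkpoint is actually loaded.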