Lazy import + refactor Lora layer addition (#426)

* lazy model import in mlx_lm

* change lora loading

* fix olmo lora

* remove a bunch of unused stuff from plamo

* move phixtral to mlx-lm and out of llms/
Awni Hannun 2024-02-12 10:51:02 -08:00 committed by GitHub
parent 4576946151
commit d4666615bb
15 changed files with 127 additions and 393 deletions

View File

@@ -9,7 +9,8 @@ from mlx.utils import tree_flatten
from .tuner.lora import LoRALinear
from .tuner.trainer import TrainingArgs, evaluate, train
from .utils import LORA_SUPPORTED_MODELS, generate, load
from .tuner.utils import linear_to_lora_layers
from .utils import generate, load
def build_parser():
@@ -169,19 +170,10 @@ if __name__ == "__main__":
print("Loading pretrained model")
model, tokenizer = load(args.model)
if model.__class__ not in LORA_SUPPORTED_MODELS:
raise ValueError(
f"Model {model.__class__} not supported. "
f"Supported models: {LORA_SUPPORTED_MODELS}"
)
# Freeze all layers other than LORA linears
# Freeze all layers
model.freeze()
for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
if hasattr(l, "block_sparse_moe"):
l.block_sparse_moe.gate = LoRALinear.from_linear(l.block_sparse_moe.gate)
# Convert linear layers to lora layers and unfreeze in the process
linear_to_lora_layers(model, args.lora_layers)
p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
print(f"Total parameters {p:.3f}M")

View File

@@ -9,6 +9,7 @@ from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
@@ -18,7 +19,6 @@ class ModelArgs(BaseModelArgs):
num_key_value_heads: int = None
rope_theta: float = 10000
rope_traditional: bool = False
model_type: str = None
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
def __post_init__(self):
@@ -190,6 +190,7 @@ class LlamaModel(nn.Module):
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = LlamaModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

View File

@@ -21,9 +21,9 @@ class ModelArgs(BaseModelArgs):
num_local_experts: int = 8
rms_norm_eps: float = 1e-5
vocab_size: int
model_type: str
rope_theta: float = 1e6
rope_traditional: bool = False
model_type: str = None
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
def __post_init__(self):
@@ -252,6 +252,7 @@ class MixtralModel(nn.Module):
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = MixtralModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

View File

@@ -7,18 +7,25 @@ import mlx.nn as nn
from .base import BaseModelArgs
try:
import hf_olmo
except ImportError:
print("To run olmo install ai2-olmo: pip install ai2-olmo")
exit(1)
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
d_model: int
n_layers: int
mlp_hidden_size: int
n_heads: int
vocab_size: int
embedding_size: int
model_type: str
rope_theta: float = 10000
rope_traditional: bool = False
model_type: str = None
mlp_ratio: int = 4
weight_tying: bool = False
@@ -162,11 +169,7 @@ class OlmoModel(nn.Module):
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
try:
import hf_olmo
except ImportError:
print("To run olmo install ai2-olmo: pip install ai2-olmo")
exit(1)
self.model_type = args.model_type
self.model = OlmoModel(args)
def __call__(

View File

@@ -10,6 +10,7 @@ from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
max_position_embeddings: int = 2048
vocab_size: int = 51200
hidden_size: int = 2560
@@ -163,6 +164,7 @@ class PhiModel(nn.Module):
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.model_type = config.model_type
self.model = PhiModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)

View File

@@ -8,6 +8,7 @@ from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten
from transformers import AutoTokenizer
@@ -15,6 +16,7 @@ from transformers import AutoTokenizer
@dataclass
class ModelArgs:
model_type: str
max_sequence_length: int = 2048
num_vocab: int = 51200
model_dim: int = 2560
@@ -110,30 +112,37 @@ class MOE(nn.Module):
self.mlp = [MLP(self.dim, self.hidden_dim) for _ in range(self.num_experts)]
self.gate = nn.Linear(args.model_dim, self.num_experts, bias=False)
def __call__(self, x) -> mx.array:
def __call__(self, x: mx.array) -> mx.array:
ne = self.num_experts_per_tok
orig_shape = x.shape
x = x.reshape(-1, x.shape[-1])
gates = self.gate(x)
if ne < self.num_experts:
inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
else:
inds = mx.broadcast_to(mx.arange(ne), gates.shape)
inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne, axis=-1))[:, :ne]
scores = mx.softmax(
mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
axis=-1,
).astype(gates.dtype)
y = []
for xt, st, it in zip(x, scores, inds.tolist()):
yt = mx.concatenate([self.mlp[e](xt)[:, None] for e in it], axis=-1)
yt = (yt * st).sum(axis=-1)
y.append(yt[None, :])
yc = mx.concatenate(y)
if self.training:
ys = []
y = mx.zeros((x.shape[0], ne, x.shape[-1]))
for e, expert in enumerate(self.mlp):
idx1, idx2 = map(mx.array, np.where(inds == e))
if idx1.size == 0:
continue
y[idx1, idx2] = expert(x[idx1])
return yc.reshape(orig_shape)
y = (y * scores[..., None]).sum(axis=1)
else:
y = []
for xt, st, it in zip(x, scores, inds.tolist()):
yt = mx.concatenate([self.mlp[e](xt)[:, None] for e in it], axis=-1)
yt = (yt * st).sum(axis=-1)
y.append(yt[None, :])
y = mx.concatenate(y)
return y.reshape(orig_shape)
class ParallelBlock(nn.Module):
@@ -190,6 +199,7 @@ class OutputHead(nn.Module):
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.model_type = config.model_type
self.transformer = TransformerDecoder(config)
self.lm_head = OutputHead(config)
@@ -206,57 +216,3 @@ class Model(nn.Module):
y, cache = self.transformer(x, mask, cache)
return self.lm_head(y), cache
def generate(prompt: mx.array, model: Model, temp: float = 0.0):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
y = prompt
cache = None
while True:
logits, cache = model(y[None], cache=cache)
logits = logits[:, -1, :]
y = sample(logits)
yield y
def load(path_or_hf_repo: str):
# If the path exists, it will try to load model form it
# otherwise download and cache from the hf_repo and cache
model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
)
)
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
quantization = config.get("quantization", None)
model_args = ModelArgs.from_dict(config)
weight_files = glob.glob(str(model_path / "*.safetensors"))
if len(weight_files) == 0:
raise FileNotFoundError("No safetensors found in {}".format(model_path))
weights = {}
for wf in weight_files:
weights.update(mx.load(wf).items())
model = Model(model_args)
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
model.load_weights(list(weights.items()))
mx.eval(model.parameters())
tokenizer = AutoTokenizer.from_pretrained(
model_path,
)
return model, tokenizer
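For reference, here is a standalone sketch of the top-k routing used in the refactored `MOE.__call__` above. The token count, expert count, and router logits are toy values chosen only for illustration; nothing beyond `mlx.core` is assumed.

```python
import mlx.core as mx

# Toy router: 5 tokens, 4 experts, each token routed to its top-2 experts.
num_experts, ne = 4, 2
gates = mx.random.normal((5, num_experts))  # router logits

# Indices of the top-ne experts per token (same argpartition trick as in MOE.__call__).
inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne, axis=-1))[:, :ne]

# Normalize the selected logits into per-expert mixing weights.
scores = mx.softmax(
    mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
    axis=-1,
).astype(gates.dtype)

print(inds.shape, scores.shape)  # (5, 2) (5, 2)
```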

View File

@@ -1,126 +1,25 @@
from typing import Any, List, NamedTuple, Optional, Tuple, Union
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from transformers import PretrainedConfig
from .base import BaseModelArgs
class DecoderInput(NamedTuple):
hidden_states: mx.array
position_ids: mx.array
attention_mask: Optional[mx.array] = None
past_key_values: Optional[List[mx.array]] = None
output_hidden_states: Optional[bool] = False
output_attentions: Optional[bool] = False
use_cache: Optional[bool] = False
gradient_checkpointing: bool = False
class DecoderOutput(NamedTuple):
hidden_states: mx.array
all_hidden_states: Optional[Tuple[mx.array, ...]]
all_self_attns: Optional[Tuple[mx.array, ...]]
next_decoder_cache: Optional[Tuple[mx.array, ...]]
class ModelArgs(PretrainedConfig): # type: ignore
model_type: str = "plamo"
def __init__(
self,
vocab_size: int = 32000,
hidden_size: int = 4096,
intermediate_size: int = 13312,
num_hidden_layers: int = 32,
num_attention_heads: int = 32,
max_position_embeddings: int = 2048,
initializer_range: float = 0.02,
rms_norm_eps: float = 1e-6,
use_cache: bool = True,
tokenizer_class: str = "PlamoTokenizer",
pad_token_id: Optional[int] = None,
bos_token_id: int = 1,
eos_token_id: int = 2,
n_shared_head: int = 8,
tie_word_embeddings: bool = False,
**kwargs: Any,
) -> None:
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.n_shared_head = n_shared_head
super().__init__(
tokenizer_class=tokenizer_class,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
class RotaryEmbedding:
def __init__(
self, dim: int, max_position_embeddings: int = 2048, base: int = 10000
) -> None:
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.inv_freq = 1.0 / mx.power(
self.base, mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim
)
self.cos_cached = mx.zeros((1, 1, max_position_embeddings, dim))
self.sin_cached = mx.zeros((1, 1, max_position_embeddings, dim))
self._set_cos_sin_cache(max_position_embeddings)
def _set_cos_sin_cache(self, seq_len: int) -> None:
self.max_seq_len_cached = seq_len
t = mx.arange(self.max_seq_len_cached) # type: ignore
freqs = mx.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = mx.concatenate((freqs, freqs), axis=-1)
self.cos_cached = emb.cos()[None, None, :, :]
self.sin_cached = emb.sin()[None, None, :, :]
def __call__(self, x: mx.array, seq_len: int) -> Tuple[mx.array, mx.array]:
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len)
return (
self.cos_cached[:, :, :seq_len, ...].astype(x.dtype), # type: ignore
self.sin_cached[:, :, :seq_len, ...].astype(x.dtype), # type: ignore
)
def _rotate_half(x: mx.array) -> mx.array:
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return mx.concatenate((-x2, x1), axis=-1)
def _rotary_pos_emb(
x: mx.array, cos: mx.array, sin: mx.array, position_ids: mx.array
) -> mx.array:
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = mx.squeeze(cos, (0, 1)) # [seq_len, dim]
sin = mx.squeeze(sin, (0, 1)) # [seq_len, dim]
cos = cos[position_ids][:, None] # [bs, 1, seq_len, dim]
sin = sin[position_ids][:, None] # [bs, 1, seq_len, dim]
x_embed = (x * cos) + (_rotate_half(x) * sin)
return x_embed
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
n_shared_head: int = (8,)
rope_theta: float = 10000
rope_traditional: bool = False
class RMSNorm(nn.Module):
@@ -143,7 +42,6 @@ class Attention(nn.Module):
self.config = config
self.hidden_size = config.hidden_size
head_dim = self.hidden_size // config.num_attention_heads
self.max_position_embeddings = config.max_position_embeddings
self.q_num_heads = config.num_attention_heads
self.qk_dim = self.v_dim = head_dim
@@ -165,15 +63,17 @@ class Attention(nn.Module):
self.o_proj = nn.Linear(
self.q_num_heads * self.v_dim, self.hidden_size, bias=False
)
self.rotary_emb = RotaryEmbedding(
self.qk_dim, max_position_embeddings=self.max_position_embeddings
self.rotary_emb = nn.RoPE(
head_dim,
traditional=config.rope_traditional,
base=config.rope_theta,
scale=1.0,
)
def __call__(
self,
hidden_states: mx.array,
attention_mask: Optional[mx.array] = None,
position_ids: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
bsz, q_len, _ = hidden_states.shape
@@ -204,13 +104,11 @@ class Attention(nn.Module):
key_states = _expand_kv(key_states)
value_states = _expand_kv(value_states)
kv_seq_len = key_states.shape[-2]
kv_seq_len = 0
if cache is not None:
kv_seq_len += cache[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
assert position_ids is not None
query_states = _rotary_pos_emb(query_states, cos, sin, position_ids)
key_states = _rotary_pos_emb(key_states, cos, sin, position_ids)
query_states = self.rotary_emb(query_states, offset=kv_seq_len)
key_states = self.rotary_emb(key_states, offset=kv_seq_len)
if cache is not None:
# reuse k, v, self_attention
@@ -235,10 +133,9 @@ class MLP(nn.Module):
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = nn.silu
def __call__(self, x: mx.array) -> mx.array:
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) # type: ignore
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) # type: ignore
class PlamoDecoderLayer(nn.Module):
@@ -254,7 +151,6 @@ class PlamoDecoderLayer(nn.Module):
self,
hidden_states: mx.array,
attention_mask: Optional[mx.array] = None,
position_ids: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> Tuple[Any, ...]:
# from LlamaDecoder
@@ -266,18 +162,14 @@ class PlamoDecoderLayer(nn.Module):
hidden_states_sa, cache = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
cache=cache,
)
# Fully Connected
hidden_states_mlp = self.mlp(hidden_states)
# Residual ("Parallel Layers" is used here, which is different from the normal residual connection)
# See "GPT-NeoX-20B: An Open-Source Autoregressive Language Model" for Parallel Layers
hidden_states = residual + hidden_states_sa + hidden_states_mlp
return hidden_states, cache # type: ignore
return hidden_states, cache
class PlamoDecoder(nn.Module):
@@ -289,24 +181,14 @@ class PlamoDecoder(nn.Module):
class PlamoModel(nn.Module):
config_class = ModelArgs
_no_split_modules: List[str]
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["PlamoDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = PlamoDecoder(config) # type: ignore
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
def __call__(
self,
@@ -326,10 +208,9 @@ class PlamoModel(nn.Module):
else:
if cache[0] is not None:
past_key_values_length = cache[0][0].shape[2]
position_ids = _create_position_ids(h.shape[1], past_key_values_length)
for e, layer in enumerate(self.layers.layers):
h, c = layer(h, mask, position_ids, cache[e])
h, c = layer(h, mask, cache[e])
if cache is not None:
cache[e] = c
else:
@@ -338,22 +219,13 @@ class PlamoModel(nn.Module):
return self.norm(h), cache
def _create_position_ids(seq_length: int, past_key_values_length: int = 0) -> mx.array:
# create position_ids on the fly for batch generation
position_ids = mx.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=mx.int64
)
position_ids = position_ids[None, ...].reshape(-1, seq_length)
return position_ids
class Model(nn.Module):
def __init__(self, config: PretrainedConfig) -> None:
def __init__(self, args: ModelArgs) -> None:
super().__init__()
self.model = PlamoModel(config)
self.model_type = args.model_type
self.model = PlamoModel(args)
self.lm_head: nn.Module = nn.Linear(
config.hidden_size, config.vocab_size, bias=False
args.hidden_size, args.vocab_size, bias=False
)
def __call__(
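The plamo cleanup above replaces the hand-rolled `RotaryEmbedding` and the `position_ids` plumbing with `nn.RoPE`, which takes a sequence offset instead of explicit position ids. A tiny sketch of that call pattern, with made-up shapes and cache length:

```python
import mlx.core as mx
import mlx.nn as nn

head_dim = 64
rope = nn.RoPE(head_dim, traditional=False, base=10000, scale=1.0)

# (batch, num_heads, seq_len, head_dim) queries for one decoding step,
# with 12 tokens already in the KV cache.
q = mx.random.normal((1, 8, 1, head_dim))
q_rot = rope(q, offset=12)  # offset replaces explicit position_ids
print(q_rot.shape)          # (1, 8, 1, 64)
```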

View File

@@ -9,6 +9,7 @@ from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int = 2048
num_attention_heads: int = 16
num_hidden_layers: int = 24
@@ -160,6 +161,7 @@ class QwenModel(nn.Module):
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.model_type = config.model_type
self.transformer = QwenModel(config)
self.lm_head = nn.Linear(
config.hidden_size, config.vocab_size, bias=not config.no_bias

View File

@@ -9,6 +9,7 @@ from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
@@ -190,6 +191,7 @@ class Qwen2Model(nn.Module):
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = Qwen2Model(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

View File

@@ -11,6 +11,7 @@ from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
max_position_embeddings: int
model_type: str
vocab_size: int
hidden_size: int
num_attention_heads: int
@@ -169,6 +170,7 @@ class StableLM(nn.Module):
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.model_type = config.model_type
self.model = StableLM(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

View File

@@ -7,6 +7,44 @@ from mlx.utils import tree_unflatten
from .lora import LoRALinear
def linear_to_lora_layers(model: nn.Module, num_lora_layers: int):
"""
Convert some of the models linear layers to lora layers.
Args:
model (nn.Module): The neural network model.
num_lora_layers (int): The number of blocks to convert to lora layers
starting from the last layer.
"""
if model.model_type in [
"mistral",
"llama",
"phi",
"mixtral",
"stablelm_epoch",
"qwen2",
]:
for l in model.model.layers[len(model.model.layers) - num_lora_layers :]:
l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
if hasattr(l, "block_sparse_moe"):
l.block_sparse_moe.gate = LoRALinear.from_linear(
l.block_sparse_moe.gate
)
elif model.model_type == "olmo":
for l in model.model.transformer.blocks[
len(model.model.transformer.blocks) - num_lora_layers :
]:
l.att_proj = LoRALinear.from_linear(l.att_proj)
elif model.model_type == "phi-msft":
for l in model.transformer.h[len(model.transformer.h) - num_lora_layers :]:
l.mixer.Wqkv = LoRALinear.from_linear(l.mixer.Wqkv)
l.moe.gate = LoRALinear.from_linear(l.moe.gate)
else:
raise ValueError(f"Lora does not support {model.model_type}")
def apply_lora_layers(model: nn.Module, adapter_file: str) -> nn.Module:
"""
Apply LoRA layers to the model.
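A minimal sketch of how the new helper is meant to be called from a training script, mirroring the `lora.py` change above. The model path is a placeholder, and `load` is the existing `mlx_lm.utils.load`:

```python
from mlx_lm.tuner.utils import linear_to_lora_layers
from mlx_lm.utils import load

model, tokenizer = load("path/to/model")  # placeholder local path or HF repo id

# Freeze everything, then convert the linear layers in the last N blocks to
# LoRA layers; the helper unfreezes the new LoRA parameters in the process.
model.freeze()
linear_to_lora_layers(model, num_lora_layers=16)
```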

View File

@@ -1,5 +1,6 @@
import copy
import glob
import importlib
import json
import logging
import time
@@ -12,28 +13,14 @@ from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
# Local imports
from .models import llama, mixtral, olmo, phi2, plamo, qwen, qwen2, stablelm_epoch
from .tuner.utils import apply_lora_layers
# Constants
MODEL_MAPPING = {
"llama": llama,
"mistral": llama, # mistral is compatible with llama
"mixtral": mixtral,
"phi": phi2,
"stablelm_epoch": stablelm_epoch,
"qwen": qwen,
"plamo": plamo,
"olmo": olmo,
"qwen2": qwen2,
MODEL_REMAPPING = {
"mistral": "llama", # mistral is compatible with llama
"phi-msft": "phixtral",
}
LORA_SUPPORTED_MODELS = [
llama.Model,
mixtral.Model,
phi2.Model,
stablelm_epoch.Model,
qwen2.Model,
]
MAX_FILE_SIZE_GB = 5
linear_class_predicate = (
@@ -54,12 +41,14 @@ def _get_classes(config: dict):
A tuple containing the Model class and the ModelArgs class.
"""
model_type = config["model_type"]
if model_type not in MODEL_MAPPING:
model_type = MODEL_REMAPPING.get(model_type, model_type)
try:
arch = importlib.import_module(f"mlx_lm.models.{model_type}")
except ImportError:
msg = f"Model type {model_type} not supported."
logging.error(msg)
raise ValueError(msg)
arch = MODEL_MAPPING[model_type]
return arch.Model, arch.ModelArgs
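With the eager `MODEL_MAPPING` gone, an architecture module is imported only when its `model_type` is requested, so optional dependencies such as `ai2-olmo` are not pulled in unless an OLMo model is actually loaded. A rough usage sketch, assuming the package's top-level `load`/`generate` re-exports (both live in `mlx_lm.utils` per this diff) and an illustrative Hugging Face repo id:

```python
from mlx_lm import generate, load

# A Mistral checkpoint resolves to mlx_lm.models.llama via MODEL_REMAPPING;
# no other model modules (olmo, plamo, ...) are imported along the way.
model, tokenizer = load("mistralai/Mistral-7B-Instruct-v0.2")  # example repo id
print(generate(model, tokenizer, prompt="Explain LoRA in one sentence."))
```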

View File

@@ -1,28 +0,0 @@
# Phixtral
Phixtral is a Mixture of Experts (MoE) architecture inspired by
[Mixtral](../mixtral/README.md) but made by combining fine-tuned versions of
Phi-2.[^1][^2]
### Setup
Install the dependencies:
```
pip install -r requirements.txt
```
### Run
```
python generate.py \
--model mlabonne/phixtral-4x2_8 \
--prompt "write a quick sort in Python"
```
Run `python generate.py --help` to see all the options.
[^1]: For more details on Phixtral, see the [Hugging Face repo](https://huggingface.co/mlabonne/phixtral-4x2_8).
[^2]: For more details on Phi-2 see Microsoft's [blog post](
https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/)
and the [Hugging Face repo](https://huggingface.co/microsoft/phi-2).

View File

@@ -1,91 +0,0 @@
# Copyright © 2023 Apple Inc.
import argparse
import time
import mlx.core as mx
import phixtral
import transformers
def generate(
model: phixtral.Model,
tokenizer: transformers.AutoTokenizer,
prompt: str,
max_tokens: int,
temp: float = 0.0,
):
print("[INFO] Generating with Phixtral...", flush=True)
print(prompt, end="", flush=True)
prompt = tokenizer(
prompt,
return_tensors="np",
return_attention_mask=False,
)[
"input_ids"
][0]
prompt = mx.array(prompt)
tic = time.time()
tokens = []
skip = 0
for token, n in zip(
phixtral.generate(prompt, model, temp),
range(max_tokens),
):
if token == tokenizer.eos_token_id:
break
if n == 0:
prompt_time = time.time() - tic
tic = time.time()
tokens.append(token.item())
# if (n + 1) % 10 == 0:
s = tokenizer.decode(tokens)
print(s[skip:], end="", flush=True)
skip = len(s)
print(tokenizer.decode(tokens)[skip:], flush=True)
gen_time = time.time() - tic
print("=" * 10)
if len(tokens) == 0:
print("No tokens generated for this prompt")
return
prompt_tps = prompt.size / prompt_time
gen_tps = (len(tokens) - 1) / gen_time
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="inference script")
parser.add_argument(
"--model",
type=str,
default="mlx_model",
help="The path to the local model directory or Hugging Face repo.",
)
parser.add_argument(
"--prompt",
help="The message to be processed by the model",
default="Write a detailed analogy between mathematics and a lighthouse.",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=100,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--temp",
help="The sampling temperature.",
type=float,
default=0.0,
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
args = parser.parse_args()
mx.random.seed(args.seed)
model, tokenizer = phixtral.load(args.model)
generate(model, tokenizer, args.prompt, args.max_tokens, args.temp)

View File

@@ -1,7 +0,0 @@
einops
hf_transfer
huggingface_hub
mlx
numpy
torch
transformers>=4.35