Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-11-06 06:28:10 +08:00
Added lora support for Phi-2 (#302)
* Added lora support for Phi-2
* Added Phi-2 support in fuse and convert
* format + readme

---------

Co-authored-by: Awni Hannun <awni@apple.com>
lora/utils.py
@@ -2,12 +2,44 @@
import glob
import json
import logging
from pathlib import Path
from typing import Generator

import mlx.core as mx
import mlx.nn as nn
import models.llama as llama
import models.phi2 as phi2
import transformers
from huggingface_hub import snapshot_download

# Constants
MODEL_MAPPING = {
    "llama": llama,
    "mistral": llama,  # mistral is compatible with llama
    "phi": phi2,
}


def _get_classes(config: dict):
    """
    Retrieve the model and model args classes based on the configuration.

    Args:
        config (dict): The model configuration.

    Returns:
        A tuple containing the Model class and the ModelArgs class.
    """
    model_type = config["model_type"]
    if model_type not in MODEL_MAPPING:
        msg = f"Model type {model_type} not supported."
        logging.error(msg)
        raise ValueError(msg)

    arch = MODEL_MAPPING[model_type]
    return arch.Model, arch.ModelArgs


def fetch_from_hub(hf_path: str):
    model_path = snapshot_download(
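For context on the hunk above (not part of the diff): adding "phi" to MODEL_MAPPING is what routes Phi-2 checkpoints to the new models.phi2 module. A minimal sketch of the dispatch, with an illustrative config dict:

# Sketch only: the config below is illustrative, not taken from the diff.
config = {"model_type": "phi"}           # as read from a Phi-2 config.json
Model, ModelArgs = _get_classes(config)  # resolves to models.phi2.Model / models.phi2.ModelArgs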
@@ -88,3 +120,71 @@ def save_model(save_dir: str, weights, tokenizer, config):
    tokenizer.save_pretrained(save_dir)
    with open(save_dir / "config.json", "w") as fid:
        json.dump(config, fid, indent=4)


def load(path_or_hf_repo: str):
    # If the path exists, try to load the model from it;
    # otherwise download it from the hf_repo and use the local cache.
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
        model_path = Path(
            snapshot_download(
                repo_id=path_or_hf_repo,
                allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
            )
        )

    with open(model_path / "config.json", "r") as f:
        config = json.loads(f.read())
    quantization = config.get("quantization", None)

    weight_files = glob.glob(str(model_path / "*.safetensors"))
    if len(weight_files) == 0:
        raise FileNotFoundError("No safetensors found in {}".format(model_path))

    weights = {}
    for wf in weight_files:
        weights.update(mx.load(wf).items())

    model_class, model_args_class = _get_classes(config=config)
    model_args = model_args_class.from_dict(config)
    model = model_class(model_args)
    if quantization is not None:
        nn.QuantizedLinear.quantize_module(model, **quantization)

    model.load_weights(list(weights.items()))

    mx.eval(model.parameters())
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
    return model, tokenizer, config
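As a usage sketch (not part of the diff): load accepts either a local directory or a Hugging Face repo id. The path below is illustrative and assumes the weights are already in MLX format (e.g. produced by the convert step mentioned in the commit message):

# Sketch only: the path is illustrative.
model, tokenizer, config = load("path/to/mlx-phi-2")
print(config["model_type"])  # "phi" -> dispatched to models.phi2 via MODEL_MAPPING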


def generate(
    prompt: mx.array, model: nn.Module, temp: float = 0.0
) -> Generator[mx.array, None, None]:
    """
    Generate text based on the given prompt and model.

    Args:
        prompt (mx.array): The input prompt.
        model (nn.Module): The model to use for generation.
        temp (float): The temperature for sampling. If temp is 0, use greedy (argmax) sampling.

    Yields:
        mx.array: The generated text.
    """

    def sample(logits: mx.array) -> mx.array:
        return (
            mx.argmax(logits, axis=-1)
            if temp == 0
            else mx.random.categorical(logits * (1 / temp))
        )

    y = prompt
    cache = None
    while True:
        logits, cache = model(y[None], cache=cache)
        logits = logits[:, -1, :]
        y = sample(logits)
        yield y
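generate is an infinite generator, so the caller supplies the stop condition. A minimal consumption loop (not part of the diff; the prompt text, token limit, and EOS handling are illustrative, and the tokenizer is assumed to define eos_token_id), using model and tokenizer from load above:

# Sketch only: prompt and max token count are illustrative.
model, tokenizer, _ = load("path/to/mlx-phi-2")
prompt = mx.array(tokenizer.encode("Write a haiku about the ocean."))

tokens = []
for token, _ in zip(generate(prompt, model, temp=0.7), range(100)):
    if token.item() == tokenizer.eos_token_id:
        break
    tokens.append(token.item())
print(tokenizer.decode(tokens))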