Added lora support for Phi-2 (#302)

* Added lora support for Phi-2

* Added Phi-2 support in fuse and convert

* format + readme

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Yousif
2024-01-12 13:45:30 -08:00
committed by GitHub
parent 3ac731dd4f
commit 7575125d5d
12 changed files with 564 additions and 25 deletions

View File

@@ -2,12 +2,44 @@
import glob
import json
import logging
from pathlib import Path
from typing import Generator
import mlx.core as mx
import mlx.nn as nn
import models.llama as llama
import models.phi2 as phi2
import transformers
from huggingface_hub import snapshot_download
# Constants
MODEL_MAPPING = {
"llama": llama,
"mistral": llama, # mistral is compatible with llama
"phi": phi2,
}
def _get_classes(config: dict):
"""
Retrieve the model and model args classes based on the configuration.
Args:
config (dict): The model configuration.
Returns:
A tuple containing the Model class and the ModelArgs class.
"""
model_type = config["model_type"]
if model_type not in MODEL_MAPPING:
msg = f"Model type {model_type} not supported."
logging.error(msg)
raise ValueError(msg)
arch = MODEL_MAPPING[model_type]
return arch.Model, arch.ModelArgs
def fetch_from_hub(hf_path: str):
model_path = snapshot_download(
@@ -88,3 +120,71 @@ def save_model(save_dir: str, weights, tokenizer, config):
tokenizer.save_pretrained(save_dir)
with open(save_dir / "config.json", "w") as fid:
json.dump(config, fid, indent=4)
def load(path_or_hf_repo: str):
# If the path exists, it will try to load model form it
# otherwise download and cache from the hf_repo and cache
model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
)
)
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
quantization = config.get("quantization", None)
weight_files = glob.glob(str(model_path / "*.safetensors"))
if len(weight_files) == 0:
raise FileNotFoundError("No safetensors found in {}".format(model_path))
weights = {}
for wf in weight_files:
weights.update(mx.load(wf).items())
model_class, model_args_class = _get_classes(config=config)
model_args = model_args_class.from_dict(config)
model = model_class(model_args)
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
model.load_weights(list(weights.items()))
mx.eval(model.parameters())
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
return model, tokenizer, config
def generate(
prompt: mx.array, model: nn.Module, temp: float = 0.0
) -> Generator[mx.array, None, None]:
"""
Generate text based on the given prompt and model.
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
temp (float): The temperature for sampling. If temp is 0, use max sampling.
Yields:
mx.array: The generated text.
"""
def sample(logits: mx.array) -> mx.array:
return (
mx.argmax(logits, axis=-1)
if temp == 0
else mx.random.categorical(logits * (1 / temp))
)
y = prompt
cache = None
while True:
logits, cache = model(y[None], cache=cache)
logits = logits[:, -1, :]
y = sample(logits)
yield y