refactor: allow the phi2 example to load models directly from Hugging Face without a conversion step (#253)

* refactor: allow the phi2 example to load models directly from Hugging Face without a conversion step

* chore: add `super().__init__()` to every module; without it, LoRA training raises an error
Anchen 2024-01-08 06:01:23 -08:00 committed by GitHub
parent 9742ad0f51
commit 6e5b0de4d3
4 changed files with 313 additions and 170 deletions

llms/phi2/README.md

@@ -7,63 +7,52 @@ GPT-4 outputs and clean web text.
Phi-2 efficiently runs on Apple silicon devices with 8GB of memory in 16-bit
precision.

### Setup

Install the dependencies:

```
pip install -r requirements.txt
```
### Run

```
python generate.py --model <model_path> --prompt "hello"
```

For example:

```
python generate.py --model microsoft/phi-2 --prompt "hello"
```

The `<model_path>` should be either a path to a local directory or a Hugging
Face repo with weights stored in `safetensors` format. If you use a repo from
the Hugging Face Hub, then the model will be downloaded and cached the first
time you run it.

Run `python generate.py --help` to see all the options.
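For reference, `generate.py` is a thin command-line wrapper around the `phi2` module in this example (`phi2.load` and `phi2.generate`). A minimal programmatic sketch, with an illustrative prompt and token budget:

```
import mlx.core as mx
import phi2

model, tokenizer = phi2.load("microsoft/phi-2")
prompt = mx.array(
    tokenizer(
        "Write a detailed analogy between mathematics and a lighthouse.",
        return_tensors="np",
        return_attention_mask=False,
    )["input_ids"][0]
)

tokens = []
for token, _ in zip(phi2.generate(prompt, model, temp=0.0), range(100)):
    if token == tokenizer.eos_token_id:
        break
    tokens.append(token.item())
print(tokenizer.decode(tokens))
```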
### Convert new models

You can convert (change the data type or quantize) models using the
`convert.py` script. This script takes a Hugging Face repo as input and outputs
a model directory (which you can optionally also upload to Hugging Face).

For example, to make a 4-bit quantized model, run:

```
python convert.py --hf-path <hf_repo> -q
```
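Under the hood, the `-q` flag quantizes the model's linear layers with MLX's `nn.QuantizedLinear.quantize_module`, the same call `convert.py` and the loader in `phi2.py` use. A rough sketch, with an illustrative group size and bit width (controlled by `--q-group-size` and `--q-bits` in the script):

```
import mlx.nn as nn
from phi2 import Model, ModelArgs

model = Model(ModelArgs())                        # full-precision model
nn.QuantizedLinear.quantize_module(model, 64, 4)  # group size 64, 4-bit weights
```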
For more options run:

```
python convert.py --help
```

You can upload new models to the [Hugging Face MLX
Community](https://huggingface.co/mlx-community) by specifying `--upload-name`
to `convert.py`.

[^1]: For more details on the model see the [blog post](
https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/)
and the [Hugging Face repo](https://huggingface.co/microsoft/phi-2)

llms/phi2/convert.py

@@ -1,23 +1,43 @@
import argparse
import copy
import glob
import json
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import transformers
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
from phi2 import Model, ModelArgs


def fetch_from_hub(hf_path: str):
    model_path = snapshot_download(
        repo_id=hf_path,
        allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
    )
    weight_files = glob.glob(f"{model_path}/*.safetensors")
    if len(weight_files) == 0:
        raise FileNotFoundError("No safetensors found in {}".format(model_path))

    weights = {}
    for wf in weight_files:
        weights.update(mx.load(wf).items())

    config = transformers.AutoConfig.from_pretrained(hf_path, trust_remote_code=True)
    tokenizer = transformers.AutoTokenizer.from_pretrained(hf_path)
    return weights, config.to_dict(), tokenizer
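As a quick sanity check, `fetch_from_hub` can be exercised on its own. A minimal sketch, assuming it is run from the example directory so `convert.py` is importable (the repo id and printed keys are only illustrative):

```python
from convert import fetch_from_hub

weights, config, tokenizer = fetch_from_hub("microsoft/phi-2")
print(len(weights), "weight tensors")  # flat {name: mx.array} mapping
print(config["vocab_size"])            # plain dict from AutoConfig.to_dict()
print(tokenizer.eos_token)             # standard transformers tokenizer
```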
def quantize(weights, config, args):
    quantized_config = copy.deepcopy(config)

    # Load the model:
    model = Model(ModelArgs.from_dict(config))
    model.load_weights(list(weights.items()))

    # Quantize the model:
    nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
@@ -32,22 +52,69 @@ def quantize(weights, config, args):
    return quantized_weights, quantized_config


def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
    max_file_size_bytes = max_file_size_gibibyte << 30
    shards = []
    shard, shard_size = {}, 0
    for k, v in weights.items():
        estimated_size = v.size * v.dtype.size
        if shard_size + estimated_size > max_file_size_bytes:
            shards.append(shard)
            shard, shard_size = {}, 0
        shard[k] = v
        shard_size += estimated_size
    shards.append(shard)
    return shards
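The size estimate above is just element count times bytes per element. A small sketch of the same arithmetic (the shape is arbitrary):

```python
import mlx.core as mx

w = mx.zeros((2560, 2560), dtype=mx.float16)
print(w.size * w.dtype.size)  # 13107200 bytes, i.e. 12.5 MiB for this tensor
```

Tensors are packed into the current shard until the running total would exceed the limit (15 GiB by default), at which point a new shard is started.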
def upload_to_hub(path: str, name: str, hf_path: str):
    import os

    from huggingface_hub import HfApi, ModelCard, logging

    repo_id = f"mlx-community/{name}"

    card = ModelCard.load(hf_path)
    card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
    card.text = f"""
# {name}
This model was converted to MLX format from [`{hf_path}`]().
Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
## Use with mlx
```bash
pip install mlx
git clone https://github.com/ml-explore/mlx-examples.git
cd mlx-examples/llms/hf_llm
python generate.py --model {repo_id} --prompt "My name is"
```
"""
    card.save(os.path.join(path, "README.md"))

    logging.set_verbosity_info()

    api = HfApi()
    api.create_repo(repo_id=repo_id, exist_ok=True)
    api.upload_folder(
        folder_path=path,
        repo_id=repo_id,
        repo_type="model",
    )
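A hypothetical invocation after a local conversion (the model name here is made up):

```python
upload_to_hub("mlx_model", "phi-2-4bit-mlx", "microsoft/phi-2")
# creates mlx-community/phi-2-4bit-mlx, writes a model card, and uploads the folder
```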
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert Hugging Face model to MLX format"
    )
    parser.add_argument(
        "--hf-path",
        type=str,
        help="Path to the Hugging Face model.",
    )
    parser.add_argument(
        "--mlx-path",
        type=str,
        default="mlx_model",
        help="Path to save the MLX model.",
    )
    parser.add_argument(
        "-q",
@@ -67,26 +134,39 @@ def convert():
        type=int,
        default=4,
    )
    parser.add_argument(
        "--dtype",
        help="Type to save the parameters, ignored if -q is given.",
        type=str,
        choices=["float16", "bfloat16", "float32"],
        default="float16",
    )
    parser.add_argument(
        "--upload-name",
        help="The name of model to upload to Hugging Face MLX Community",
        type=str,
        default=None,
    )

    args = parser.parse_args()

    print("[INFO] Loading")
    weights, config, tokenizer = fetch_from_hub(args.hf_path)

    dtype = mx.float16 if args.quantize else getattr(mx, args.dtype)
    weights = {k: v.astype(dtype) for k, v in weights.items()}
    if args.quantize:
        print("[INFO] Quantizing")
        weights, config = quantize(weights, config, args)

    mlx_path = Path(args.mlx_path)
    mlx_path.mkdir(parents=True, exist_ok=True)

    shards = make_shards(weights)
    for i, shard in enumerate(shards):
        mx.save_safetensors(str(mlx_path / f"weights.{i:02d}.safetensors"), shard)
    tokenizer.save_pretrained(mlx_path)

    with open(mlx_path / "config.json", "w") as fid:
        json.dump(config, fid, indent=4)

    if args.upload_name is not None:
        upload_to_hub(mlx_path, args.upload_name, args.hf_path)
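Once converted, the output directory can be consumed directly by this example's loader. A short sketch using the script's default output path:

```python
import phi2

model, tokenizer = phi2.load("mlx_model")  # reads config.json and the *.safetensors shards
```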

llms/phi2/generate.py (new file, 91 lines)

@@ -0,0 +1,91 @@
# Copyright © 2023 Apple Inc.

import argparse
import time

import mlx.core as mx
import phi2
import transformers


def generate(
    model: phi2.Model,
    tokenizer: transformers.AutoTokenizer,
    prompt: str,
    max_tokens: int,
    temp: float = 0.0,
):
    print("[INFO] Generating with Phi-2...", flush=True)
    print(prompt, end="", flush=True)

    # Tokenize the prompt and move it into an MLX array.
    prompt = tokenizer(
        prompt,
        return_tensors="np",
        return_attention_mask=False,
    )["input_ids"][0]
    prompt = mx.array(prompt)

    tic = time.time()
    tokens = []
    skip = 0
    for token, n in zip(
        phi2.generate(prompt, model, temp),
        range(max_tokens),
    ):
        if token == tokenizer.eos_token_id:
            break
        if n == 0:
            prompt_time = time.time() - tic
            tic = time.time()
        tokens.append(token.item())
        # Stream the newly decoded text as it is produced.
        s = tokenizer.decode(tokens)
        print(s[skip:], end="", flush=True)
        skip = len(s)
    print(tokenizer.decode(tokens)[skip:], flush=True)
    gen_time = time.time() - tic
    print("=" * 10)
    if len(tokens) == 0:
        print("No tokens generated for this prompt")
        return
    prompt_tps = prompt.size / prompt_time
    gen_tps = (len(tokens) - 1) / gen_time
    print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
    print(f"Generation: {gen_tps:.3f} tokens-per-sec")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="inference script")
    parser.add_argument(
        "--model",
        type=str,
        default="mlx_model",
        help="The path to the local model directory or Hugging Face repo.",
    )
    parser.add_argument(
        "--prompt",
        help="The message to be processed by the model",
        default="Write a detailed analogy between mathematics and a lighthouse.",
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=100,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temp",
        help="The sampling temperature.",
        type=float,
        default=0.0,
    )
    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")

    args = parser.parse_args()

    mx.random.seed(args.seed)
    model, tokenizer = phi2.load(args.model)
    generate(model, tokenizer, args.prompt, args.max_tokens, args.temp)
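Because the CLI logic sits under the `__main__` guard, `generate` can also be imported and driven directly. A minimal sketch with an illustrative prompt and settings:

```python
import mlx.core as mx
import phi2
from generate import generate

mx.random.seed(0)
model, tokenizer = phi2.load("microsoft/phi-2")
generate(model, tokenizer, "Explain KV caching in one sentence.", max_tokens=60, temp=0.7)
```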

llms/phi2/phi2.py

@@ -1,4 +1,6 @@
import argparse
import glob
import inspect
import json
import math
from dataclasses import dataclass
@@ -7,6 +9,7 @@ from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten
from transformers import AutoTokenizer
@@ -20,6 +23,16 @@ class ModelArgs:
    num_layers: int = 32
    rotary_dim: int = 32

    @classmethod
    def from_dict(cls, params):
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )
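`from_dict` keeps only the keys that match `ModelArgs` fields, so a raw Hugging Face `config.json` with extra entries can be passed straight through. A small sketch (the dictionary values are illustrative):

```python
params = {
    "num_layers": 32,
    "rotary_dim": 32,
    "architectures": ["PhiForCausalLM"],  # unknown key, silently dropped
}
args = ModelArgs.from_dict(params)
print(args.num_layers, args.rotary_dim)
```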
class LayerNorm(nn.LayerNorm):
    def __call__(self, x: mx.array) -> mx.array:
@@ -75,6 +88,17 @@ class RoPEAttention(nn.Module):
        return self.out_proj(values_hat), (keys, values)


class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.act = nn.GELU(approx="precise")

    def __call__(self, x) -> mx.array:
        return self.fc2(self.act(self.fc1(x)))
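A quick shape check for the feed-forward block; the 4x expansion mirrors `ParallelBlock` below, but the numbers here are only illustrative:

```python
import mlx.core as mx

mlp = MLP(dim=2560, hidden_dim=4 * 2560)
x = mx.zeros((1, 8, 2560))
print(mlp(x).shape)  # (1, 8, 2560): expand, GELU, project back down
```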
class ParallelBlock(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
@@ -82,23 +106,23 @@ class ParallelBlock(nn.Module):
        mlp_dims = dims * 4
        self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
        self.ln = LayerNorm(dims)
        self.mlp = MLP(dims, mlp_dims)

    def __call__(self, x, mask, cache):
        h = self.ln(x)
        attn_h, cache = self.mixer(h, mask, cache)
        ff_h = self.mlp(h)
        return attn_h + ff_h + x, cache


class TransformerDecoder(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.embd = Embd(config)
        self.h = [ParallelBlock(config) for i in range(config.num_layers)]

    def __call__(self, x, mask, cache):
        x = self.embd(x)
        if cache is None:
            cache = [None] * len(self.h)
@@ -107,8 +131,18 @@ class TransformerDecoder(nn.Module):
        return x, cache


class Embd(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.wte = nn.Embedding(config.num_vocab, config.model_dim)

    def __call__(self, x):
        return self.wte(x)


class OutputHead(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.ln = LayerNorm(config.model_dim)
        self.linear = nn.Linear(config.model_dim, config.num_vocab)
@@ -116,20 +150,18 @@ class OutputHead(nn.Module):
        return self.linear(self.ln(inputs))


class Model(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.transformer = TransformerDecoder(config)
        self.lm_head = OutputHead(config)

    def __call__(
        self,
        x: mx.array,
        mask: mx.array = None,
        cache: mx.array = None,
    ) -> tuple[mx.array, mx.array]:
        mask = None
        if x.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
@@ -139,104 +171,55 @@ class Phi2(nn.Module):
        return self.lm_head(y), cache


def generate(prompt: mx.array, model: Model, temp: float = 0.0):
    def sample(logits):
        if temp == 0:
            return mx.argmax(logits, axis=-1)
        else:
            return mx.random.categorical(logits * (1 / temp))

    y = prompt
    cache = None
    while True:
        logits, cache = model(y[None], cache=cache)
        logits = logits[:, -1, :]
        y = sample(logits)
        yield y
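The `sample` helper reduces to greedy decoding at `temp == 0` and to sampling from temperature-scaled logits otherwise. A tiny standalone sketch of the two branches (the logits are made up):

```python
import mlx.core as mx

logits = mx.array([[2.0, 0.5, 0.1]])
print(mx.argmax(logits, axis=-1))                 # greedy: always index 0
mx.random.seed(0)
print(mx.random.categorical(logits * (1 / 0.8)))  # temp = 0.8: stochastic draw
```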
def load(path_or_hf_repo: str):
    # If the path exists, load the model from it; otherwise download
    # it from the Hugging Face Hub and cache it locally.
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
        model_path = Path(
            snapshot_download(
                repo_id=path_or_hf_repo,
                allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
            )
        )

    with open(model_path / "config.json", "r") as f:
        config = json.loads(f.read())
        quantization = config.get("quantization", None)
        model_args = ModelArgs.from_dict(config)

    weight_files = glob.glob(str(model_path / "*.safetensors"))
    if len(weight_files) == 0:
        raise FileNotFoundError("No safetensors found in {}".format(model_path))

    weights = {}
    for wf in weight_files:
        weights.update(mx.load(wf).items())

    model = Model(model_args)
    if quantization is not None:
        nn.QuantizedLinear.quantize_module(model, **quantization)

    model.load_weights(list(weights.items()))
    mx.eval(model.parameters())

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    return model, tokenizer