From 6e5b0de4d350c6ae740c133fc6b4733b4b464b06 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Mon, 8 Jan 2024 06:01:23 -0800
Subject: [PATCH] refactor: make the phi2 example can be directly load the
model from hf without convert needed (#253)
* refactor: make the phi2 example can be directly load the model from hf without convert needed
* chore: add super().__init__() for all module, otherwise will cause error in lora
---
llms/phi2/README.md | 73 ++++++++----------
llms/phi2/convert.py | 148 +++++++++++++++++++++++++++---------
llms/phi2/generate.py | 91 ++++++++++++++++++++++
llms/phi2/phi2.py | 171 +++++++++++++++++++-----------------------
4 files changed, 313 insertions(+), 170 deletions(-)
create mode 100644 llms/phi2/generate.py
diff --git a/llms/phi2/README.md b/llms/phi2/README.md
index c79dd5e8..086cd17e 100644
--- a/llms/phi2/README.md
+++ b/llms/phi2/README.md
@@ -7,63 +7,52 @@ GPT-4 outputs and clean web text.
Phi-2 efficiently runs on Apple silicon devices with 8GB of memory in 16-bit
precision.
-## Setup
+### Setup
-Download and convert the model:
-
-```sh
-python convert.py
-```
-
-To generate a 4-bit quantized model use the `-q` flag:
+Install the dependencies:
```
-python convert.py -q
+pip install -r requirements.txt
```
-By default, the conversion script will make the directory `mlx_model` and save
-the converted `weights.npz`, and `config.json` there.
-
-> [!TIP] Alternatively, you can also download a few converted checkpoints from
-> the [MLX Community](https://huggingface.co/mlx-community) organization on
-> Hugging Face and skip the conversion step.
-
-
-## Generate
-
-To generate text with the default prompt:
-
-```sh
-python phi2.py
+### Run
```
-
-Should give the output:
+python generate.py --model --prompt "hello"
+```
+For example:
```
-Answer: Mathematics is like a lighthouse that guides us through the darkness of
-uncertainty. Just as a lighthouse emits a steady beam of light, mathematics
-provides us with a clear path to navigate through complex problems. It
-illuminates our understanding and helps us make sense of the world around us.
+python generate.py --model microsoft/phi-2 --prompt "hello"
+```
+The `` should be either a path to a local directory or a Hugging
+Face repo with weights stored in `safetensors` format. If you use a repo from
+the Hugging Face Hub, then the model will be downloaded and cached the first
+time you run it.
-Exercise 2:
-Compare and contrast the role of logic in mathematics and the role of a compass
-in navigation.
+Run `python generate.py --help` to see all the options.
-Answer: Logic in mathematics is like a compass in navigation. It helps
+### Convert new models
+
+You can convert (change the data type or quantize) models using the
+`convert.py` script. This script takes a Hugging Face repo as input and outputs
+a model directory (which you can optionally also upload to Hugging Face).
+
+For example, to make 4-bit quantized a model, run:
+
+```
+python convert.py --hf-path -q
```
-To use your own prompt:
+For more options run:
-```sh
-python phi2.py --prompt --max-tokens
+```
+python convert.py --help
```
-To see a list of options run:
-
-```sh
-python phi2.py --help
-```
+You can upload new models to the [Hugging Face MLX
+Community](https://huggingface.co/mlx-community) by specifying `--upload-name``
+to `convert.py`.
[^1]: For more details on the model see the [blog post](
https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/)
-and the [Hugging Face repo](https://huggingface.co/microsoft/phi-2)
+and the [Hugging Face repo](https://huggingface.co/microsoft/phi-2)
\ No newline at end of file
diff --git a/llms/phi2/convert.py b/llms/phi2/convert.py
index 0cb5e519..4cac6e82 100644
--- a/llms/phi2/convert.py
+++ b/llms/phi2/convert.py
@@ -1,23 +1,43 @@
import argparse
import copy
+import glob
import json
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
-import numpy as np
-from mlx.utils import tree_flatten, tree_map, tree_unflatten
-from phi2 import ModelArgs, Phi2
-from transformers import AutoModelForCausalLM
+import transformers
+from huggingface_hub import snapshot_download
+from mlx.utils import tree_flatten
+from phi2 import Model, ModelArgs
+
+
+def fetch_from_hub(hf_path: str):
+ model_path = snapshot_download(
+ repo_id=hf_path,
+ allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
+ )
+ weight_files = glob.glob(f"{model_path}/*.safetensors")
+ if len(weight_files) == 0:
+ raise FileNotFoundError("No safetensors found in {}".format(model_path))
+
+ weights = {}
+ for wf in weight_files:
+ weights.update(mx.load(wf).items())
+
+ config = transformers.AutoConfig.from_pretrained(hf_path, trust_remote_code=True)
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ hf_path,
+ )
+ return weights, config.to_dict(), tokenizer
def quantize(weights, config, args):
quantized_config = copy.deepcopy(config)
# Load the model:
- model = Phi2(ModelArgs())
- weights = tree_map(mx.array, weights)
- model.update(tree_unflatten(list(weights.items())))
+ model = Model(ModelArgs.from_dict(config))
+ model.load_weights(list(weights.items()))
# Quantize the model:
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
@@ -32,22 +52,69 @@ def quantize(weights, config, args):
return quantized_weights, quantized_config
-def replace_key(key: str) -> str:
- if "wte.weight" in key:
- key = "wte.weight"
-
- if ".mlp" in key:
- key = key.replace(".mlp", "")
- return key
+def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
+ max_file_size_bytes = max_file_size_gibibyte << 30
+ shards = []
+ shard, shard_size = {}, 0
+ for k, v in weights.items():
+ estimated_size = v.size * v.dtype.size
+ if shard_size + estimated_size > max_file_size_bytes:
+ shards.append(shard)
+ shard, shard_size = {}, 0
+ shard[k] = v
+ shard_size += estimated_size
+ shards.append(shard)
+ return shards
-def convert():
- parser = argparse.ArgumentParser(description="Convert Phi-2 weights to MLX")
+def upload_to_hub(path: str, name: str, hf_path: str):
+ import os
+
+ from huggingface_hub import HfApi, ModelCard, logging
+
+ repo_id = f"mlx-community/{name}"
+
+ card = ModelCard.load(hf_path)
+ card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
+ card.text = f"""
+# {name}
+This model was converted to MLX format from [`{hf_path}`]().
+Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
+## Use with mlx
+```bash
+pip install mlx
+git clone https://github.com/ml-explore/mlx-examples.git
+cd mlx-examples/llms/hf_llm
+python generate.py --model {repo_id} --prompt "My name is"
+```
+"""
+ card.save(os.path.join(path, "README.md"))
+
+ logging.set_verbosity_info()
+
+ api = HfApi()
+ api.create_repo(repo_id=repo_id, exist_ok=True)
+ api.upload_folder(
+ folder_path=path,
+ repo_id=repo_id,
+ repo_type="model",
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Convert Hugging Face model to MLX format"
+ )
+ parser.add_argument(
+ "--hf-path",
+ type=str,
+ help="Path to the Hugging Face model.",
+ )
parser.add_argument(
"--mlx-path",
type=str,
default="mlx_model",
- help="The path to save the MLX model.",
+ help="Path to save the MLX model.",
)
parser.add_argument(
"-q",
@@ -67,26 +134,39 @@ def convert():
type=int,
default=4,
)
+ parser.add_argument(
+ "--dtype",
+ help="Type to save the parameters, ignored if -q is given.",
+ type=str,
+ choices=["float16", "bfloat16", "float32"],
+ default="float16",
+ )
+ parser.add_argument(
+ "--upload-name",
+ help="The name of model to upload to Hugging Face MLX Community",
+ type=str,
+ default=None,
+ )
+
args = parser.parse_args()
+ print("[INFO] Loading")
+ weights, config, tokenizer = fetch_from_hub(args.hf_path)
+
+ dtype = mx.float16 if args.quantize else getattr(mx, args.dtype)
+ weights = {k: v.astype(dtype) for k, v in weights.items()}
+ if args.quantize:
+ print("[INFO] Quantizing")
+ weights, config = quantize(weights, config, args)
+
mlx_path = Path(args.mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
-
- model = AutoModelForCausalLM.from_pretrained(
- "microsoft/phi-2", torch_dtype="auto", trust_remote_code=True
- )
- state_dict = model.state_dict()
- weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
- params = {}
- if args.quantize:
- print("[INFO] Quantizing")
- weights, params = quantize(weights, params, args)
-
- np.savez(str(mlx_path / "weights.npz"), **weights)
+ shards = make_shards(weights)
+ for i, shard in enumerate(shards):
+ mx.save_safetensors(str(mlx_path / f"weights.{i:02d}.safetensors"), shard)
+ tokenizer.save_pretrained(mlx_path)
with open(mlx_path / "config.json", "w") as fid:
- params["model_type"] = "phi2"
- json.dump(params, fid, indent=4)
+ json.dump(config, fid, indent=4)
-
-if __name__ == "__main__":
- convert()
+ if args.upload_name is not None:
+ upload_to_hub(mlx_path, args.upload_name, args.hf_path)
diff --git a/llms/phi2/generate.py b/llms/phi2/generate.py
new file mode 100644
index 00000000..6ba63ce3
--- /dev/null
+++ b/llms/phi2/generate.py
@@ -0,0 +1,91 @@
+# Copyright © 2023 Apple Inc.
+
+import argparse
+import time
+
+import mlx.core as mx
+import phi2
+import transformers
+
+
+def generate(
+ model: phi2.Model,
+ tokenizer: transformers.AutoTokenizer,
+ prompt: str,
+ max_tokens: int,
+ temp: float = 0.0,
+):
+ print("[INFO] Generating with Phi-2...", flush=True)
+ print(args.prompt, end="", flush=True)
+ prompt = tokenizer(
+ prompt,
+ return_tensors="np",
+ return_attention_mask=False,
+ )[
+ "input_ids"
+ ][0]
+ prompt = mx.array(prompt)
+
+ tic = time.time()
+ tokens = []
+ skip = 0
+ for token, n in zip(
+ phi2.generate(prompt, model, args.temp),
+ range(args.max_tokens),
+ ):
+ if token == tokenizer.eos_token_id:
+ break
+
+ if n == 0:
+ prompt_time = time.time() - tic
+ tic = time.time()
+
+ tokens.append(token.item())
+ # if (n + 1) % 10 == 0:
+ s = tokenizer.decode(tokens)
+ print(s[skip:], end="", flush=True)
+ skip = len(s)
+ print(tokenizer.decode(tokens)[skip:], flush=True)
+ gen_time = time.time() - tic
+ print("=" * 10)
+ if len(tokens) == 0:
+ print("No tokens generated for this prompt")
+ return
+ prompt_tps = prompt.size / prompt_time
+ gen_tps = (len(tokens) - 1) / gen_time
+ print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
+ print(f"Generation: {gen_tps:.3f} tokens-per-sec")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="inference script")
+ parser.add_argument(
+ "--model",
+ type=str,
+ default="mlx_model",
+ help="The path to the local model directory or Hugging Face repo.",
+ )
+ parser.add_argument(
+ "--prompt",
+ help="The message to be processed by the model",
+ default="Write a detailed analogy between mathematics and a lighthouse.",
+ )
+ parser.add_argument(
+ "--max-tokens",
+ "-m",
+ type=int,
+ default=100,
+ help="Maximum number of tokens to generate",
+ )
+ parser.add_argument(
+ "--temp",
+ help="The sampling temperature.",
+ type=float,
+ default=0.0,
+ )
+ parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
+
+ args = parser.parse_args()
+ mx.random.seed(args.seed)
+ model, tokenizer = phi2.load(args.model)
+ generate(model, tokenizer, args.prompt, args.max_tokens, args.temp)
diff --git a/llms/phi2/phi2.py b/llms/phi2/phi2.py
index f824549d..8154acf3 100644
--- a/llms/phi2/phi2.py
+++ b/llms/phi2/phi2.py
@@ -1,4 +1,6 @@
import argparse
+import glob
+import inspect
import json
import math
from dataclasses import dataclass
@@ -7,6 +9,7 @@ from typing import Optional
import mlx.core as mx
import mlx.nn as nn
+from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten
from transformers import AutoTokenizer
@@ -20,6 +23,16 @@ class ModelArgs:
num_layers: int = 32
rotary_dim: int = 32
+ @classmethod
+ def from_dict(cls, params):
+ return cls(
+ **{
+ k: v
+ for k, v in params.items()
+ if k in inspect.signature(cls).parameters
+ }
+ )
+
class LayerNorm(nn.LayerNorm):
def __call__(self, x: mx.array) -> mx.array:
@@ -75,6 +88,17 @@ class RoPEAttention(nn.Module):
return self.out_proj(values_hat), (keys, values)
+class MLP(nn.Module):
+ def __init__(self, dim, hidden_dim):
+ super().__init__()
+ self.fc1 = nn.Linear(dim, hidden_dim)
+ self.fc2 = nn.Linear(hidden_dim, dim)
+ self.act = nn.GELU(approx="precise")
+
+ def __call__(self, x) -> mx.array:
+ return self.fc2(self.act(self.fc1(x)))
+
+
class ParallelBlock(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
@@ -82,23 +106,23 @@ class ParallelBlock(nn.Module):
mlp_dims = dims * 4
self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
self.ln = LayerNorm(dims)
- self.fc1 = nn.Linear(dims, mlp_dims)
- self.fc2 = nn.Linear(mlp_dims, dims)
- self.act = nn.GELU(approx="precise")
+ self.mlp = MLP(dims, mlp_dims)
def __call__(self, x, mask, cache):
h = self.ln(x)
attn_h, cache = self.mixer(h, mask, cache)
- ff_h = self.fc2(self.act(self.fc1(h)))
+ ff_h = self.mlp(h)
return attn_h + ff_h + x, cache
class TransformerDecoder(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
+ self.embd = Embd(config)
self.h = [ParallelBlock(config) for i in range(config.num_layers)]
def __call__(self, x, mask, cache):
+ x = self.embd(x)
if cache is None:
cache = [None] * len(self.h)
@@ -107,8 +131,18 @@ class TransformerDecoder(nn.Module):
return x, cache
+class Embd(nn.Module):
+ def __init__(self, config: ModelArgs):
+ super().__init__()
+ self.wte = nn.Embedding(config.num_vocab, config.model_dim)
+
+ def __call__(self, x):
+ return self.wte(x)
+
+
class OutputHead(nn.Module):
def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
self.ln = LayerNorm(config.model_dim)
self.linear = nn.Linear(config.model_dim, config.num_vocab)
@@ -116,20 +150,18 @@ class OutputHead(nn.Module):
return self.linear(self.ln(inputs))
-class Phi2(nn.Module):
+class Model(nn.Module):
def __init__(self, config: ModelArgs):
- self.wte = nn.Embedding(config.num_vocab, config.model_dim)
+ super().__init__()
self.transformer = TransformerDecoder(config)
self.lm_head = OutputHead(config)
def __call__(
self,
- inputs: mx.array,
+ x: mx.array,
mask: mx.array = None,
cache: mx.array = None,
) -> tuple[mx.array, mx.array]:
- x = self.wte(inputs)
-
mask = None
if x.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
@@ -139,104 +171,55 @@ class Phi2(nn.Module):
return self.lm_head(y), cache
-def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0):
+def generate(prompt: mx.array, model: Model, temp: float = 0.0):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
- logits, cache = model(prompt)
- y = sample(logits[:, -1, :])
- yield y
-
+ y = prompt
+ cache = None
while True:
- logits, cache = model(y[:, None], cache=cache)
- y = sample(logits.squeeze(1))
+ logits, cache = model(y[None], cache=cache)
+ logits = logits[:, -1, :]
+ y = sample(logits)
yield y
-def load_model(model_path: str):
- model = Phi2(ModelArgs())
- model_path = Path(model_path)
+def load(path_or_hf_repo: str):
+ # If the path exists, it will try to load model form it
+ # otherwise download and cache from the hf_repo and cache
+ model_path = Path(path_or_hf_repo)
+ if not model_path.exists():
+ model_path = Path(
+ snapshot_download(
+ repo_id=path_or_hf_repo,
+ allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
+ )
+ )
+
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
- config.pop("model_type", None)
- quantization = config.pop("quantization", None)
- weights = mx.load(str(model_path / "weights.npz"))
- weights = tree_unflatten(list(weights.items()))
+ quantization = config.get("quantization", None)
+ model_args = ModelArgs.from_dict(config)
+
+ weight_files = glob.glob(str(model_path / "*.safetensors"))
+ if len(weight_files) == 0:
+ raise FileNotFoundError("No safetensors found in {}".format(model_path))
+
+ weights = {}
+ for wf in weight_files:
+ weights.update(mx.load(wf).items())
+
+ model = Model(model_args)
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
- model.update(weights)
- tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
+ model.load_weights(list(weights.items()))
+
+ mx.eval(model.parameters())
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path,
+ )
return model, tokenizer
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser(description="Phi-2 inference script")
- parser.add_argument(
- "--model-path",
- type=str,
- default="mlx_model",
- help="The path to the model weights",
- )
- parser.add_argument(
- "--prompt",
- help="The message to be processed by the model",
- default="Write a detailed analogy between mathematics and a lighthouse.",
- )
- parser.add_argument(
- "--max-tokens",
- "-m",
- type=int,
- default=100,
- help="Maximum number of tokens to generate",
- )
- parser.add_argument(
- "--temp",
- help="The sampling temperature.",
- type=float,
- default=0.0,
- )
- parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
- args = parser.parse_args()
-
- mx.random.seed(args.seed)
-
- model, tokenizer = load_model(args.model_path)
-
- prompt = tokenizer(
- args.prompt,
- return_tensors="np",
- return_attention_mask=False,
- )["input_ids"]
-
- prompt = mx.array(prompt)
-
- print("[INFO] Generating with Phi-2...", flush=True)
- print(args.prompt, end="", flush=True)
-
- tokens = []
- for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
- tokens.append(token)
-
- if (len(tokens) % 10) == 0:
- mx.eval(tokens)
- eos_index = next(
- (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
- None,
- )
-
- if eos_index is not None:
- tokens = tokens[:eos_index]
-
- s = tokenizer.decode([t.item() for t in tokens])
- print(s, end="", flush=True)
- tokens = []
- if eos_index is not None:
- break
-
- mx.eval(tokens)
- s = tokenizer.decode([t.item() for t in tokens])
- print(s, flush=True)