diff --git a/llms/phi2/README.md b/llms/phi2/README.md
index c79dd5e8..086cd17e 100644
--- a/llms/phi2/README.md
+++ b/llms/phi2/README.md
@@ -7,63 +7,52 @@ GPT-4 outputs and clean web text.
 Phi-2 efficiently runs on Apple silicon devices with 8GB of memory in 16-bit
 precision.
 
-## Setup
+### Setup
 
-Download and convert the model:
-
-```sh
-python convert.py
-```
-
-To generate a 4-bit quantized model use the `-q` flag:
+Install the dependencies:
 
 ```
-python convert.py -q
+pip install -r requirements.txt
 ```
 
-By default, the conversion script will make the directory `mlx_model` and save
-the converted `weights.npz`, and `config.json` there.
-
-> [!TIP] Alternatively, you can also download a few converted checkpoints from
-> the [MLX Community](https://huggingface.co/mlx-community) organization on
-> Hugging Face and skip the conversion step.
-
-
-## Generate
-
-To generate text with the default prompt:
-
-```sh
-python phi2.py
+### Run
 
 ```
-
-Should give the output:
+python generate.py --model <model_path> --prompt "hello"
+```
 
+For example:
+
 ```
-Answer: Mathematics is like a lighthouse that guides us through the darkness of
-uncertainty. Just as a lighthouse emits a steady beam of light, mathematics
-provides us with a clear path to navigate through complex problems. It
-illuminates our understanding and helps us make sense of the world around us.
+python generate.py --model microsoft/phi-2 --prompt "hello"
+```
 
+The `<model_path>` should be either a path to a local directory or a Hugging
+Face repo with weights stored in `safetensors` format. If you use a repo from
+the Hugging Face Hub, then the model will be downloaded and cached the first
+time you run it.
 
-Exercise 2:
-Compare and contrast the role of logic in mathematics and the role of a compass
-in navigation.
+Run `python generate.py --help` to see all the options.
 
-Answer: Logic in mathematics is like a compass in navigation. It helps
+### Convert new models
+
+You can convert (change the data type or quantize) models using the
+`convert.py` script. This script takes a Hugging Face repo as input and outputs
+a model directory (which you can optionally also upload to Hugging Face).
+
+For example, to make a 4-bit quantized model, run:
+
+```
+python convert.py --hf-path <hf_repo> -q
+```
 
-To use your own prompt:
+For more options, run:
 
-```sh
-python phi2.py --prompt <prompt> --max-tokens <max_tokens>
+```
+python convert.py --help
+```
 
-To see a list of options run:
-
-```sh
-python phi2.py --help
-```
+You can upload new models to the [Hugging Face MLX
+Community](https://huggingface.co/mlx-community) by specifying `--upload-name`
+to `convert.py`.
 
 [^1]: For more details on the model see the [blog post](
 https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/)
-and the [Hugging Face repo](https://huggingface.co/microsoft/phi-2)
+and the [Hugging Face repo](https://huggingface.co/microsoft/phi-2)
\ No newline at end of file
diff --git a/llms/phi2/convert.py b/llms/phi2/convert.py
index 0cb5e519..4cac6e82 100644
--- a/llms/phi2/convert.py
+++ b/llms/phi2/convert.py
@@ -1,23 +1,43 @@
 import argparse
 import copy
+import glob
 import json
 from pathlib import Path
 
 import mlx.core as mx
 import mlx.nn as nn
-import numpy as np
-from mlx.utils import tree_flatten, tree_map, tree_unflatten
-from phi2 import ModelArgs, Phi2
-from transformers import AutoModelForCausalLM
+import transformers
+from huggingface_hub import snapshot_download
+from mlx.utils import tree_flatten
+from phi2 import Model, ModelArgs
+
+
+def fetch_from_hub(hf_path: str):
+    model_path = snapshot_download(
+        repo_id=hf_path,
+        allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
+    )
+    weight_files = glob.glob(f"{model_path}/*.safetensors")
+    if len(weight_files) == 0:
+        raise FileNotFoundError("No safetensors found in {}".format(model_path))
+
+    weights = {}
+    for wf in weight_files:
+        weights.update(mx.load(wf).items())
+
+    config = transformers.AutoConfig.from_pretrained(hf_path, trust_remote_code=True)
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        hf_path,
+    )
+    return weights, config.to_dict(), tokenizer
 
 
 def quantize(weights, config, args):
     quantized_config = copy.deepcopy(config)
 
     # Load the model:
-    model = Phi2(ModelArgs())
-    weights = tree_map(mx.array, weights)
-    model.update(tree_unflatten(list(weights.items())))
+    model = Model(ModelArgs.from_dict(config))
+    model.load_weights(list(weights.items()))
 
     # Quantize the model:
     nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
@@ -32,22 +52,69 @@ def quantize(weights, config, args):
     return quantized_weights, quantized_config
 
 
-def replace_key(key: str) -> str:
-    if "wte.weight" in key:
-        key = "wte.weight"
-
-    if ".mlp" in key:
-        key = key.replace(".mlp", "")
-    return key
+def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
+    max_file_size_bytes = max_file_size_gibibyte << 30
+    shards = []
+    shard, shard_size = {}, 0
+    for k, v in weights.items():
+        estimated_size = v.size * v.dtype.size
+        if shard_size + estimated_size > max_file_size_bytes:
+            shards.append(shard)
+            shard, shard_size = {}, 0
+        shard[k] = v
+        shard_size += estimated_size
+    shards.append(shard)
+    return shards
 
 
-def convert():
-    parser = argparse.ArgumentParser(description="Convert Phi-2 weights to MLX")
+def upload_to_hub(path: str, name: str, hf_path: str):
+    import os
+
+    from huggingface_hub import HfApi, ModelCard, logging
+
+    repo_id = f"mlx-community/{name}"
+
+    card = ModelCard.load(hf_path)
+    card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
+    card.text = f"""
+# {name}
+This model was converted to MLX format from [`{hf_path}`]().
+Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
+## Use with mlx
+```bash
+pip install mlx
+git clone https://github.com/ml-explore/mlx-examples.git
+cd mlx-examples/llms/hf_llm
+python generate.py --model {repo_id} --prompt "My name is"
+```
+"""
+    card.save(os.path.join(path, "README.md"))
+
+    logging.set_verbosity_info()
+
+    api = HfApi()
+    api.create_repo(repo_id=repo_id, exist_ok=True)
+    api.upload_folder(
+        folder_path=path,
+        repo_id=repo_id,
+        repo_type="model",
+    )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Convert Hugging Face model to MLX format"
+    )
+    parser.add_argument(
+        "--hf-path",
+        type=str,
+        help="Path to the Hugging Face model.",
+    )
     parser.add_argument(
         "--mlx-path",
         type=str,
         default="mlx_model",
-        help="The path to save the MLX model.",
+        help="Path to save the MLX model.",
     )
     parser.add_argument(
         "-q",
@@ -67,26 +134,39 @@ def convert():
         type=int,
         default=4,
     )
+    parser.add_argument(
+        "--dtype",
+        help="The data type to save the parameters in, ignored if -q is given.",
+        type=str,
+        choices=["float16", "bfloat16", "float32"],
+        default="float16",
+    )
+    parser.add_argument(
+        "--upload-name",
+        help="The name of the model to upload to the Hugging Face MLX Community.",
+        type=str,
+        default=None,
+    )
+
     args = parser.parse_args()
 
+    print("[INFO] Loading")
+    weights, config, tokenizer = fetch_from_hub(args.hf_path)
+
+    dtype = mx.float16 if args.quantize else getattr(mx, args.dtype)
+    weights = {k: v.astype(dtype) for k, v in weights.items()}
+    if args.quantize:
+        print("[INFO] Quantizing")
+        weights, config = quantize(weights, config, args)
+
     mlx_path = Path(args.mlx_path)
     mlx_path.mkdir(parents=True, exist_ok=True)
-
-    model = AutoModelForCausalLM.from_pretrained(
-        "microsoft/phi-2", torch_dtype="auto", trust_remote_code=True
-    )
-    state_dict = model.state_dict()
-    weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
-    params = {}
-    if args.quantize:
-        print("[INFO] Quantizing")
-        weights, params = quantize(weights, params, args)
-
-    np.savez(str(mlx_path / "weights.npz"), **weights)
+    shards = make_shards(weights)
+    for i, shard in enumerate(shards):
+        mx.save_safetensors(str(mlx_path / f"weights.{i:02d}.safetensors"), shard)
+    tokenizer.save_pretrained(mlx_path)
     with open(mlx_path / "config.json", "w") as fid:
-        params["model_type"] = "phi2"
-        json.dump(params, fid, indent=4)
+        json.dump(config, fid, indent=4)
 
-
-if __name__ == "__main__":
-    convert()
+    if args.upload_name is not None:
+        upload_to_hub(mlx_path, args.upload_name, args.hf_path)
diff --git a/llms/phi2/generate.py b/llms/phi2/generate.py
new file mode 100644
index 00000000..6ba63ce3
--- /dev/null
+++ b/llms/phi2/generate.py
@@ -0,0 +1,91 @@
+# Copyright © 2023 Apple Inc.
+
+import argparse
+import time
+
+import mlx.core as mx
+import phi2
+import transformers
+
+
+def generate(
+    model: phi2.Model,
+    tokenizer: transformers.AutoTokenizer,
+    prompt: str,
+    max_tokens: int,
+    temp: float = 0.0,
+):
+    print("[INFO] Generating with Phi-2...", flush=True)
+    print(prompt, end="", flush=True)
+    prompt = tokenizer(
+        prompt,
+        return_tensors="np",
+        return_attention_mask=False,
+    )[
+        "input_ids"
+    ][0]
+    prompt = mx.array(prompt)
+
+    tic = time.time()
+    tokens = []
+    skip = 0
+    for token, n in zip(
+        phi2.generate(prompt, model, temp),
+        range(max_tokens),
+    ):
+        if token == tokenizer.eos_token_id:
+            break
+
+        if n == 0:
+            prompt_time = time.time() - tic
+            tic = time.time()
+
+        tokens.append(token.item())
+        # Decode the tokens generated so far and print only the new text.
+        s = tokenizer.decode(tokens)
+        print(s[skip:], end="", flush=True)
+        skip = len(s)
+    print(tokenizer.decode(tokens)[skip:], flush=True)
+    gen_time = time.time() - tic
+    print("=" * 10)
+    if len(tokens) == 0:
+        print("No tokens generated for this prompt")
+        return
+    prompt_tps = prompt.size / prompt_time
+    gen_tps = (len(tokens) - 1) / gen_time
+    print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
+    print(f"Generation: {gen_tps:.3f} tokens-per-sec")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Phi-2 inference script")
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="mlx_model",
+        help="The path to the local model directory or Hugging Face repo.",
+    )
+    parser.add_argument(
+        "--prompt",
+        help="The message to be processed by the model",
+        default="Write a detailed analogy between mathematics and a lighthouse.",
+    )
+    parser.add_argument(
+        "--max-tokens",
+        "-m",
+        type=int,
+        default=100,
+        help="Maximum number of tokens to generate",
+    )
+    parser.add_argument(
+        "--temp",
+        help="The sampling temperature.",
+        type=float,
+        default=0.0,
+    )
+    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
+
+    args = parser.parse_args()
+    mx.random.seed(args.seed)
+    model, tokenizer = phi2.load(args.model)
+    generate(model, tokenizer, args.prompt, args.max_tokens, args.temp)
diff --git a/llms/phi2/phi2.py b/llms/phi2/phi2.py
index f824549d..8154acf3 100644
--- a/llms/phi2/phi2.py
+++ b/llms/phi2/phi2.py
@@ -1,4 +1,6 @@
 import argparse
+import glob
+import inspect
 import json
 import math
 from dataclasses import dataclass
@@ -7,6 +9,7 @@ from typing import Optional
 
 import mlx.core as mx
 import mlx.nn as nn
+from huggingface_hub import snapshot_download
 from mlx.utils import tree_unflatten
 from transformers import AutoTokenizer
 
@@ -20,6 +23,16 @@ class ModelArgs:
     num_layers: int = 32
     rotary_dim: int = 32
 
+    @classmethod
+    def from_dict(cls, params):
+        return cls(
+            **{
+                k: v
+                for k, v in params.items()
+                if k in inspect.signature(cls).parameters
+            }
+        )
+
 
 class LayerNorm(nn.LayerNorm):
     def __call__(self, x: mx.array) -> mx.array:
@@ -75,6 +88,17 @@ class RoPEAttention(nn.Module):
         return self.out_proj(values_hat), (keys, values)
 
 
+class MLP(nn.Module):
+    def __init__(self, dim, hidden_dim):
+        super().__init__()
+        self.fc1 = nn.Linear(dim, hidden_dim)
+        self.fc2 = nn.Linear(hidden_dim, dim)
+        self.act = nn.GELU(approx="precise")
+
+    def __call__(self, x) -> mx.array:
+        return self.fc2(self.act(self.fc1(x)))
+
+
 class ParallelBlock(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
@@ -82,23 +106,23 @@ class ParallelBlock(nn.Module):
         mlp_dims = dims * 4
         self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
         self.ln = LayerNorm(dims)
-        self.fc1 = nn.Linear(dims, mlp_dims)
-        self.fc2 = nn.Linear(mlp_dims, dims)
-        self.act = nn.GELU(approx="precise")
+        self.mlp = MLP(dims, mlp_dims)
 
     def __call__(self, x, mask, cache):
         h = self.ln(x)
         attn_h, cache = self.mixer(h, mask, cache)
-        ff_h = self.fc2(self.act(self.fc1(h)))
+        ff_h = self.mlp(h)
         return attn_h + ff_h + x, cache
 
 
 class TransformerDecoder(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
+        self.embd = Embd(config)
         self.h = [ParallelBlock(config) for i in range(config.num_layers)]
 
     def __call__(self, x, mask, cache):
+        x = self.embd(x)
         if cache is None:
             cache = [None] * len(self.h)
 
@@ -107,8 +131,18 @@
         return x, cache
 
 
+class Embd(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        self.wte = nn.Embedding(config.num_vocab, config.model_dim)
+
+    def __call__(self, x):
+        return self.wte(x)
+
+
 class OutputHead(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
         self.ln = LayerNorm(config.model_dim)
         self.linear = nn.Linear(config.model_dim, config.num_vocab)
 
@@ -116,20 +150,18 @@ class OutputHead(nn.Module):
         return self.linear(self.ln(inputs))
 
 
-class Phi2(nn.Module):
+class Model(nn.Module):
     def __init__(self, config: ModelArgs):
-        self.wte = nn.Embedding(config.num_vocab, config.model_dim)
+        super().__init__()
         self.transformer = TransformerDecoder(config)
         self.lm_head = OutputHead(config)
 
     def __call__(
         self,
-        inputs: mx.array,
+        x: mx.array,
         mask: mx.array = None,
         cache: mx.array = None,
     ) -> tuple[mx.array, mx.array]:
-        x = self.wte(inputs)
-
         mask = None
         if x.shape[1] > 1:
             mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
@@ -139,104 +171,55 @@
         return self.lm_head(y), cache
 
 
-def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0):
+def generate(prompt: mx.array, model: Model, temp: float = 0.0):
     def sample(logits):
         if temp == 0:
             return mx.argmax(logits, axis=-1)
         else:
             return mx.random.categorical(logits * (1 / temp))
 
-    logits, cache = model(prompt)
-    y = sample(logits[:, -1, :])
-    yield y
-
+    y = prompt
+    cache = None
     while True:
-        logits, cache = model(y[:, None], cache=cache)
-        y = sample(logits.squeeze(1))
+        logits, cache = model(y[None], cache=cache)
+        logits = logits[:, -1, :]
+        y = sample(logits)
         yield y
 
 
-def load_model(model_path: str):
-    model = Phi2(ModelArgs())
-    model_path = Path(model_path)
+def load(path_or_hf_repo: str):
+    # If the path exists, load the model from it; otherwise,
+    # download it from the Hugging Face Hub and cache it locally.
+    model_path = Path(path_or_hf_repo)
+    if not model_path.exists():
+        model_path = Path(
+            snapshot_download(
+                repo_id=path_or_hf_repo,
+                allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
+            )
+        )
+
     with open(model_path / "config.json", "r") as f:
         config = json.loads(f.read())
-    config.pop("model_type", None)
-    quantization = config.pop("quantization", None)
-    weights = mx.load(str(model_path / "weights.npz"))
-    weights = tree_unflatten(list(weights.items()))
+        quantization = config.get("quantization", None)
+        model_args = ModelArgs.from_dict(config)
+
+    weight_files = glob.glob(str(model_path / "*.safetensors"))
+    if len(weight_files) == 0:
+        raise FileNotFoundError("No safetensors found in {}".format(model_path))
+
+    weights = {}
+    for wf in weight_files:
+        weights.update(mx.load(wf).items())
+
+    model = Model(model_args)
     if quantization is not None:
         nn.QuantizedLinear.quantize_module(model, **quantization)
-    model.update(weights)
-    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
+    model.load_weights(list(weights.items()))
+
+    mx.eval(model.parameters())
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_path,
+    )
     return model, tokenizer
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Phi-2 inference script")
-    parser.add_argument(
-        "--model-path",
-        type=str,
-        default="mlx_model",
-        help="The path to the model weights",
-    )
-    parser.add_argument(
-        "--prompt",
-        help="The message to be processed by the model",
-        default="Write a detailed analogy between mathematics and a lighthouse.",
-    )
-    parser.add_argument(
-        "--max-tokens",
-        "-m",
-        type=int,
-        default=100,
-        help="Maximum number of tokens to generate",
-    )
-    parser.add_argument(
-        "--temp",
-        help="The sampling temperature.",
-        type=float,
-        default=0.0,
-    )
-    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
-    args = parser.parse_args()
-
-    mx.random.seed(args.seed)
-
-    model, tokenizer = load_model(args.model_path)
-
-    prompt = tokenizer(
-        args.prompt,
-        return_tensors="np",
-        return_attention_mask=False,
-    )["input_ids"]
-
-    prompt = mx.array(prompt)
-
-    print("[INFO] Generating with Phi-2...", flush=True)
-    print(args.prompt, end="", flush=True)
-
-    tokens = []
-    for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
-        tokens.append(token)
-
-        if (len(tokens) % 10) == 0:
-            mx.eval(tokens)
-            eos_index = next(
-                (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
-                None,
-            )
-
-            if eos_index is not None:
-                tokens = tokens[:eos_index]
-
-            s = tokenizer.decode([t.item() for t in tokens])
-            print(s, end="", flush=True)
-            tokens = []
-            if eos_index is not None:
-                break
-
-    mx.eval(tokens)
-    s = tokenizer.decode([t.item() for t in tokens])
-    print(s, flush=True)