From a2402116ae6201264f683259046741ce8fac39e7 Mon Sep 17 00:00:00 2001 From: Anchen Date: Thu, 11 Jan 2024 12:29:12 -0800 Subject: [PATCH] refactor(hf_llm): moving phi2 example into hf_llm (#293) * refactor: moving phi2 example into hf_llm * chore: clean up * chore: update phi2 model args so it can load args from config * fix phi2 + nits + readme * allow any HF repo, update README * fix bug in llama --------- Co-authored-by: Awni Hannun --- llms/hf_llm/.gitignore | 1 + llms/hf_llm/README.md | 21 +++- llms/hf_llm/convert.py | 179 +++++++++++++------------- llms/hf_llm/generate.py | 97 +++++++------- llms/hf_llm/models/.gitignore | 0 llms/hf_llm/models/__init__.py | 0 llms/hf_llm/models/base.py | 15 +++ llms/hf_llm/models/llama.py | 202 +++++++++++++++++++++++++++++ llms/hf_llm/models/phi2.py | 138 ++++++++++++++++++++ llms/hf_llm/utils.py | 141 +++++++++++++++++++++ llms/phi2/README.md | 58 --------- llms/phi2/convert.py | 172 ------------------------- llms/phi2/generate.py | 91 -------------- llms/phi2/phi2.py | 224 --------------------------------- llms/phi2/requirements.txt | 5 - 15 files changed, 647 insertions(+), 697 deletions(-) create mode 100644 llms/hf_llm/.gitignore create mode 100644 llms/hf_llm/models/.gitignore create mode 100644 llms/hf_llm/models/__init__.py create mode 100644 llms/hf_llm/models/base.py create mode 100644 llms/hf_llm/models/llama.py create mode 100644 llms/hf_llm/models/phi2.py create mode 100644 llms/hf_llm/utils.py delete mode 100644 llms/phi2/README.md delete mode 100644 llms/phi2/convert.py delete mode 100644 llms/phi2/generate.py delete mode 100644 llms/phi2/phi2.py delete mode 100644 llms/phi2/requirements.txt diff --git a/llms/hf_llm/.gitignore b/llms/hf_llm/.gitignore new file mode 100644 index 00000000..9666d735 --- /dev/null +++ b/llms/hf_llm/.gitignore @@ -0,0 +1 @@ +mlx_model \ No newline at end of file diff --git a/llms/hf_llm/README.md b/llms/hf_llm/README.md index b7762be3..e2734adb 100644 --- a/llms/hf_llm/README.md +++ b/llms/hf_llm/README.md @@ -35,7 +35,7 @@ Run `python generate.py --help` to see all the options. ### Models -The example supports Hugging Face format Mistral and Llama-style models. If the +The example supports Hugging Face format Mistral, Llama, and Phi-2 style models. If the model you want to run is not supported, file an [issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet, submit a pull request. @@ -47,11 +47,13 @@ Here are a few examples of Hugging Face models that work with this example: - [TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T](https://huggingface.co/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T) - [deepseek-ai/deepseek-coder-6.7b-instruct](https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct) - [01-ai/Yi-6B-Chat](https://huggingface.co/01-ai/Yi-6B-Chat) +- [microsoft/phi-2](https://huggingface.co/microsoft/phi-2) Most -[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending) +[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending), +[Llama](https://huggingface.co/models?library=transformers,safetensors&other=llama&sort=trending), and -[Llama](https://huggingface.co/models?library=transformers,safetensors&other=llama&sort=trending) +[Phi-2](https://huggingface.co/models?library=transformers,safetensors&other=phi&sort=trending) style models should work out of the box. 
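Beyond the CLI, the converted models can also be driven from Python. The sketch below is illustrative only: it assumes the `load` and `generate` helpers added in `llms/hf_llm/utils.py` by this patch and is run from that directory; the prompt text, temperature, and token budget are arbitrary.

```python
# Minimal programmatic sketch using the helpers this patch adds in
# llms/hf_llm/utils.py. The prompt and limits are illustrative only.
import mlx.core as mx
from utils import generate, load

# Any supported Hugging Face repo or a converted local directory works here.
model, tokenizer = load("microsoft/phi-2")
prompt = mx.array(tokenizer.encode("Write a haiku about the ocean."))

tokens = []
# generate() yields one token at a time; cap the loop and stop at EOS,
# mirroring the loop in generate.py.
for token, _ in zip(generate(prompt, model, temp=0.6), range(100)):
    if token == tokenizer.eos_token_id:
        break
    tokens.append(token.item())
print(tokenizer.decode(tokens))
```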
### Convert new models @@ -72,6 +74,13 @@ For more options run: python convert.py --help ``` -You can upload new models to the [Hugging Face MLX -Community](https://huggingface.co/mlx-community) by specifying `--upload-name` -to `convert.py`. +You can upload new models to Hugging Face by specifying `--upload-repo` to +`convert.py`. For example, to upload a quantized Mistral-7B model to the +[MLX Hugging Face community](https://huggingface.co/mlx-community) you can do: + +``` +python convert.py \ + --hf-path mistralai/Mistral-7B-v0.1 \ + -q \ + --upload-repo mlx-community/my-4bit-mistral +``` diff --git a/llms/hf_llm/convert.py b/llms/hf_llm/convert.py index 3523c7f2..488c7213 100644 --- a/llms/hf_llm/convert.py +++ b/llms/hf_llm/convert.py @@ -1,52 +1,95 @@ -# Copyright © 2023 Apple Inc. - import argparse import copy import glob import json from pathlib import Path +from typing import Dict, Tuple import mlx.core as mx import mlx.nn as nn import transformers -from huggingface_hub import snapshot_download from mlx.utils import tree_flatten -from models import Model, ModelArgs +from utils import get_model_path, load + +MAX_FILE_SIZE_GB = 15 -def fetch_from_hub(model_path: str, local: bool): - if not local: - model_path = snapshot_download( - repo_id=model_path, - allow_patterns=["*.json", "*.safetensors", "tokenizer.model"], - ) +def configure_parser() -> argparse.ArgumentParser: + """ + Configures and returns the argument parser for the script. + + Returns: + argparse.ArgumentParser: Configured argument parser. + """ + parser = argparse.ArgumentParser( + description="Convert Hugging Face model to MLX format" + ) + + parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.") + parser.add_argument( + "--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model." + ) + parser.add_argument( + "-q", "--quantize", help="Generate a quantized model.", action="store_true" + ) + parser.add_argument( + "--q-group-size", help="Group size for quantization.", type=int, default=64 + ) + parser.add_argument( + "--q-bits", help="Bits per weight for quantization.", type=int, default=4 + ) + parser.add_argument( + "--dtype", + help="Type to save the parameters, ignored if -q is given.", + type=str, + choices=["float16", "bfloat16", "float32"], + default="float16", + ) + parser.add_argument( + "--upload-repo", + help="The Hugging Face repo to upload the model to.", + type=str, + default=None, + ) + return parser + + +def fetch_from_hub( + model_path: str, +) -> Tuple[Dict, dict, transformers.PreTrainedTokenizer]: + model_path = get_model_path(model_path) weight_files = glob.glob(f"{model_path}/*.safetensors") - if len(weight_files) == 0: - raise FileNotFoundError("No safetensors found in {}".format(model_path)) + if not weight_files: + raise FileNotFoundError(f"No safetensors found in {model_path}") weights = {} for wf in weight_files: weights.update(mx.load(wf).items()) config = transformers.AutoConfig.from_pretrained(model_path) - tokenizer = transformers.AutoTokenizer.from_pretrained( - model_path, - ) + tokenizer = transformers.AutoTokenizer.from_pretrained(model_path) + return weights, config.to_dict(), tokenizer -def quantize(weights, config, args): - quantized_config = copy.deepcopy(config) +def quantize(weights: dict, config: dict, args: argparse.Namespace) -> tuple: + """ + Applies quantization to the model weights. - # Load the model: - model = Model(ModelArgs.from_dict(config)) + Args: + weights (dict): Model weights. + config (dict): Model configuration.
+ args (argparse.Namespace): Command-line arguments. + + Returns: + tuple: Tuple containing quantized weights and config. + """ + quantized_config = copy.deepcopy(config) + model, _ = load(args.hf_path) model.load_weights(list(weights.items())) - # Quantize the model: nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits) - - # Update the config: quantized_config["quantization"] = { "group_size": args.q_group_size, "bits": args.q_bits, @@ -56,8 +99,18 @@ def quantize(weights, config, args): return quantized_weights, quantized_config -def make_shards(weights: dict, max_file_size_gibibyte: int = 15): - max_file_size_bytes = max_file_size_gibibyte << 30 +def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list: + """ + Splits the weights into smaller shards. + + Args: + weights (dict): Model weights. + max_file_size_gb (int): Maximum size of each shard in gigabytes. + + Returns: + list: List of weight shards. + """ + max_file_size_bytes = max_file_size_gb << 30 shards = [] shard, shard_size = {}, 0 for k, v in weights.items(): @@ -71,17 +124,23 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15): return shards -def upload_to_hub(path: str, name: str, hf_path: str): +def upload_to_hub(path: str, upload_repo: str, hf_path: str): + """ + Uploads the model to Hugging Face hub. + + Args: + path (str): Local path to the model. + upload_repo (str): Name of the HF repo to upload to. + hf_path (str): Path to the original Hugging Face model. + """ import os from huggingface_hub import HfApi, ModelCard, logging - repo_id = f"mlx-community/{name}" - card = ModelCard.load(hf_path) card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"] card.text = f""" -# {name} +# {upload_repo} This model was converted to MLX format from [`{hf_path}`](). Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model. 
## Use with mlx @@ -97,72 +156,20 @@ python generate.py --model {repo_id} --prompt "My name is" logging.set_verbosity_info() api = HfApi() - api.create_repo(repo_id=repo_id, exist_ok=True) + api.create_repo(repo_id=upload_repo, exist_ok=True) api.upload_folder( folder_path=path, - repo_id=repo_id, + repo_id=upload_repo, repo_type="model", ) if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Convert Hugging Face model to MLX format" - ) - parser.add_argument( - "--hf-path", - type=str, - help="Path to the Hugging Face model.", - ) - parser.add_argument( - "--mlx-path", - type=str, - default="mlx_model", - help="Path to save the MLX model.", - ) - parser.add_argument( - "-q", - "--quantize", - help="Generate a quantized model.", - action="store_true", - ) - parser.add_argument( - "--q-group-size", - help="Group size for quantization.", - type=int, - default=64, - ) - parser.add_argument( - "--q-bits", - help="Bits per weight for quantization.", - type=int, - default=4, - ) - parser.add_argument( - "--dtype", - help="Type to save the parameters, ignored if -q is given.", - type=str, - choices=["float16", "bfloat16", "float32"], - default="float16", - ) - parser.add_argument( - "--upload-name", - help="The name of model to upload to Hugging Face MLX Community", - type=str, - default=None, - ) - parser.add_argument( - "-l", - "--local", - action="store_true", - help="Whether the hf-path points to a local filesystem.", - default=False, - ) - + parser = configure_parser() args = parser.parse_args() print("[INFO] Loading") - weights, config, tokenizer = fetch_from_hub(args.hf_path, args.local) + weights, config, tokenizer = fetch_from_hub(args.hf_path) dtype = mx.float16 if args.quantize else getattr(mx, args.dtype) weights = {k: v.astype(dtype) for k, v in weights.items()} @@ -179,5 +186,5 @@ if __name__ == "__main__": with open(mlx_path / "config.json", "w") as fid: json.dump(config, fid, indent=4) - if args.upload_name is not None and not args.local: - upload_to_hub(mlx_path, args.upload_name, args.hf_path) + if args.upload_repo is not None: + upload_to_hub(mlx_path, args.upload_repo, args.hf_path) diff --git a/llms/hf_llm/generate.py b/llms/hf_llm/generate.py index 460a1764..1e906c89 100644 --- a/llms/hf_llm/generate.py +++ b/llms/hf_llm/generate.py @@ -1,43 +1,58 @@ -# Copyright © 2023 Apple Inc. 
- import argparse import time import mlx.core as mx -import models -import transformers +from utils import generate, load + +DEFAULT_MODEL_PATH = "mlx_model" +DEFAULT_PROMPT = "hello" +DEFAULT_MAX_TOKENS = 100 +DEFAULT_TEMP = 0.6 +DEFAULT_SEED = 0 -def generate( - model: models.Model, - tokenizer: transformers.AutoTokenizer, - prompt: str, - max_tokens: int, - temp: float = 0.0, -): - prompt = tokenizer( - prompt, - return_tensors="np", - return_attention_mask=False, - )[ - "input_ids" - ][0] +def setup_arg_parser(): + """Set up and return the argument parser.""" + parser = argparse.ArgumentParser(description="LLM inference script") + parser.add_argument( + "--model", + type=str, + default="mlx_model", + help="The path to the local model directory or Hugging Face repo.", + ) + parser.add_argument( + "--prompt", default=DEFAULT_PROMPT, help="Message to be processed by the model" + ) + parser.add_argument( + "--max-tokens", + "-m", + type=int, + default=DEFAULT_MAX_TOKENS, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature" + ) + parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed") + return parser + + +def main(args): + mx.random.seed(args.seed) + model, tokenizer = load(args.model) + print("=" * 10) + print("Prompt:", args.prompt) + prompt = tokenizer.encode(args.prompt) prompt = mx.array(prompt) - tic = time.time() tokens = [] skip = 0 - for token, n in zip( - models.generate(prompt, model, temp), - range(max_tokens), - ): + for token, n in zip(generate(prompt, model, args.temp), range(args.max_tokens)): if token == tokenizer.eos_token_id: break - if n == 0: prompt_time = time.time() - tic tic = time.time() - tokens.append(token.item()) s = tokenizer.decode(tokens) print(s[skip:], end="", flush=True) @@ -55,34 +70,6 @@ def generate( if __name__ == "__main__": - parser = argparse.ArgumentParser(description="inference script") - parser.add_argument( - "--model", - type=str, - default="mlx_model", - help="The path to the local model directory or Hugging Face repo.", - ) - parser.add_argument( - "--prompt", - help="The message to be processed by the model", - default="In the beginning the Universe was created.", - ) - parser.add_argument( - "--max-tokens", - "-m", - type=int, - default=100, - help="Maximum number of tokens to generate", - ) - parser.add_argument( - "--temp", - help="The sampling temperature.", - type=float, - default=0.0, - ) - parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") - + parser = setup_arg_parser() args = parser.parse_args() - mx.random.seed(args.seed) - model, tokenizer = models.load(args.model) - generate(model, tokenizer, args.prompt, args.max_tokens, args.temp) + main(args) diff --git a/llms/hf_llm/models/.gitignore b/llms/hf_llm/models/.gitignore new file mode 100644 index 00000000..e69de29b diff --git a/llms/hf_llm/models/__init__.py b/llms/hf_llm/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/llms/hf_llm/models/base.py b/llms/hf_llm/models/base.py new file mode 100644 index 00000000..d1ea0b2c --- /dev/null +++ b/llms/hf_llm/models/base.py @@ -0,0 +1,15 @@ +import inspect +from dataclasses import dataclass + + +@dataclass +class BaseModelArgs: + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) diff --git a/llms/hf_llm/models/llama.py b/llms/hf_llm/models/llama.py new file mode 100644 
index 00000000..ee026363 --- /dev/null +++ b/llms/hf_llm/models/llama.py @@ -0,0 +1,202 @@ +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs + + +@dataclass +class ModelArgs(BaseModelArgs): + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + rms_norm_eps: float + vocab_size: int + num_key_value_heads: int = None + rope_theta: float = 10000 + rope_traditional: bool = False + model_type: str = None + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + if self.rope_scaling: + required_keys = {"factor", "type"} + if not all(key in self.rope_scaling for key in required_keys): + raise ValueError(f"rope_scaling must contain keys {required_keys}") + + if self.rope_scaling["type"] != "linear": + raise ValueError("rope_scaling 'type' currently only supports 'linear'") + + +class RMSNorm(nn.Module): + def __init__(self, dims: int, eps: float = 1e-5): + super().__init__() + self.weight = mx.ones((dims,)) + self.eps = eps + + def _norm(self, x): + return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) + + def __call__(self, x): + output = self._norm(x.astype(mx.float32)).astype(x.dtype) + return self.weight * output + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + + self.repeats = n_heads // n_kv_heads + + head_dim = args.hidden_size // n_heads + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) + + rope_scale = ( + 1 / args.rope_scaling["factor"] + if args.rope_scaling is not None and args.rope_scaling["type"] == "linear" + else 1 + ) + self.rope = nn.RoPE( + head_dim, + traditional=args.rope_traditional, + base=args.rope_theta, + scale=rope_scale, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + def repeat(a): + a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) + return a.reshape([B, self.n_heads, L, -1]) + + if self.repeats > 1: + keys, values = map(repeat, (keys, values)) + + if cache is not None: + key_cache, value_cache = cache + queries = self.rope(queries, offset=key_cache.shape[2]) + keys = self.rope(keys, offset=key_cache.shape[2]) + keys = mx.concatenate([key_cache, keys], axis=2) + values = mx.concatenate([value_cache, values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) + if mask is not None: + scores += mask + scores = mx.softmax(scores.astype(mx.float32), 
axis=-1).astype(scores.dtype) + output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output), (keys, values) + + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.num_attention_heads = args.num_attention_heads + self.hidden_size = args.hidden_size + self.self_attn = Attention(args) + self.mlp = MLP(args.hidden_size, args.intermediate_size) + self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + r, cache = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out, cache + + +class LlamaModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.embed_tokens(inputs) + + mask = None + if h.shape[1] > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) + mask = mask.astype(h.dtype) + + if cache is None: + cache = [None] * len(self.layers) + + for e, layer in enumerate(self.layers): + h, cache[e] = layer(h, mask, cache[e]) + + return self.norm(h), cache + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.model = LlamaModel(args) + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out, cache = self.model(inputs, cache) + return self.lm_head(out), cache diff --git a/llms/hf_llm/models/phi2.py b/llms/hf_llm/models/phi2.py new file mode 100644 index 00000000..51b5e390 --- /dev/null +++ b/llms/hf_llm/models/phi2.py @@ -0,0 +1,138 @@ +import math +from dataclasses import dataclass + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs + + +@dataclass +class ModelArgs(BaseModelArgs): + n_positions: int = 2048 + vocab_size: int = 51200 + n_embd: int = 2560 + n_head: int = 32 + n_layer: int = 32 + rotary_dim: int = 32 + + +class LayerNorm(nn.LayerNorm): + def __call__(self, x: mx.array) -> mx.array: + return super().__call__(x.astype(mx.float32)).astype(x.dtype) + + +class RoPEAttention(nn.Module): + def __init__(self, dims: int, n_head: int, rotary_dim: int): + super().__init__() + + self.n_head = n_head + + self.q_proj = nn.Linear(dims, dims) + self.k_proj = nn.Linear(dims, dims) + self.v_proj = nn.Linear(dims, dims) + self.dense = nn.Linear(dims, dims) + + self.rope = nn.RoPE(rotary_dim, traditional=False) + + def __call__(self, x, mask=None, cache=None): + queries, keys, values = 
self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Extract some shapes + n_head = self.n_head + B, L, D = queries.shape + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3) + + # Add RoPE to the queries and keys and combine them with the cache + if cache is not None: + key_cache, value_cache = cache + queries = self.rope(queries, offset=key_cache.shape[2]) + keys = self.rope(keys, offset=key_cache.shape[2]) + keys = mx.concatenate([key_cache, keys], axis=2) + values = mx.concatenate([value_cache, values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + queries = queries.astype(mx.float32) + keys = keys.astype(mx.float32) + + # Finally perform the attention computation + scale = math.sqrt(1 / queries.shape[-1]) + scores = (queries * scale) @ keys.transpose(0, 1, 3, 2) + if mask is not None: + scores = scores + mask + + scores = mx.softmax(scores, axis=-1).astype(values.dtype) + values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + + return self.dense(values_hat), (keys, values) + + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.fc1 = nn.Linear(dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, dim) + self.act = nn.GELU(approx="precise") + + def __call__(self, x) -> mx.array: + return self.fc2(self.act(self.fc1(x))) + + +class ParallelBlock(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + dims = config.n_embd + mlp_dims = dims * 4 + self.self_attn = RoPEAttention(dims, config.n_head, config.rotary_dim) + self.input_layernorm = LayerNorm(dims) + self.mlp = MLP(dims, mlp_dims) + + def __call__(self, x, mask, cache): + h = self.input_layernorm(x) + attn_h, cache = self.self_attn(h, mask, cache) + ff_h = self.mlp(h) + return attn_h + ff_h + x, cache + + +class Transformer(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd) + self.layers = [ParallelBlock(config) for i in range(config.n_layer)] + self.final_layernorm = LayerNorm(config.n_embd) + + def __call__(self, x, mask, cache): + x = self.embed_tokens(x) + if cache is None: + cache = [None] * len(self.layers) + + for e, layer in enumerate(self.layers): + x, cache[e] = layer(x, mask, cache[e]) + return self.final_layernorm(x), cache + + +class Model(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.model = Transformer(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size) + + def __call__( + self, + x: mx.array, + mask: mx.array = None, + cache: mx.array = None, + ) -> tuple[mx.array, mx.array]: + mask = None + if x.shape[1] > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) + mask = mask.astype(x.dtype) + + y, cache = self.model(x, mask, cache) + return self.lm_head(y), cache diff --git a/llms/hf_llm/utils.py b/llms/hf_llm/utils.py new file mode 100644 index 00000000..71d8941f --- /dev/null +++ b/llms/hf_llm/utils.py @@ -0,0 +1,141 @@ +import glob +import json +import logging +from pathlib import Path +from typing import Generator, Tuple + +import mlx.core as mx +import mlx.nn as nn + +# Local imports +import models.llama as llama +import models.phi2 as phi2 +from huggingface_hub import snapshot_download +from models.base import BaseModelArgs +from 
transformers import AutoTokenizer, PreTrainedTokenizer + +# Constants +MODEL_MAPPING = { + "llama": llama, + "mistral": llama, # mistral is compatible with llama + "phi": phi2, +} + + +def _get_classes(config: dict): + """ + Retrieve the model and model args classes based on the configuration. + + Args: + config (dict): The model configuration. + + Returns: + A tuple containing the Model class and the ModelArgs class. + """ + model_type = config["model_type"] + if model_type not in MODEL_MAPPING: + msg = f"Model type {model_type} not supported." + logging.error(msg) + raise ValueError(msg) + + arch = MODEL_MAPPING[model_type] + return arch.Model, arch.ModelArgs + + +def get_model_path(path_or_hf_repo: str) -> Path: + """ + Ensures the model is available locally. If the path does not exist locally, + it is downloaded from the Hugging Face Hub. + + Args: + path_or_hf_repo (str): The local path or Hugging Face repository ID of the model. + + Returns: + Path: The path to the model. + """ + model_path = Path(path_or_hf_repo) + if not model_path.exists(): + model_path = Path( + snapshot_download( + repo_id=path_or_hf_repo, + allow_patterns=["*.json", "*.safetensors", "*.py", "tokenizer.model"], + ) + ) + return model_path + + +def generate( + prompt: mx.array, model: nn.Module, temp: float = 0.0 +) -> Generator[mx.array, None, None]: + """ + Generate text based on the given prompt and model. + + Args: + prompt (mx.array): The input prompt. + model (nn.Module): The model to use for generation. + temp (float): The temperature for sampling. If temp is 0, use max sampling. + + Yields: + mx.array: The generated text. + """ + + def sample(logits: mx.array) -> mx.array: + return ( + mx.argmax(logits, axis=-1) + if temp == 0 + else mx.random.categorical(logits * (1 / temp)) + ) + + y = prompt + cache = None + while True: + logits, cache = model(y[None], cache=cache) + logits = logits[:, -1, :] + y = sample(logits) + yield y + + +def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]: + """ + Load the model from a given path or a huggingface repository. + + Args: + path_or_hf_repo (str): The path or the huggingface repository to load the model from. + + Returns: + Tuple[nn.Module, PreTrainedTokenizer]: The loaded model and tokenizer. + + Raises: + FileNotFoundError: If config file or safetensors are not found. + ValueError: If model class or args class are not found. 
+ """ + model_path = get_model_path(path_or_hf_repo) + + try: + with open(model_path / "config.json", "r") as f: + config = json.load(f) + quantization = config.get("quantization", None) + except FileNotFoundError: + logging.error(f"Config file not found in {model_path}") + raise + weight_files = glob.glob(str(model_path / "*.safetensors")) + if not weight_files: + logging.error(f"No safetensors found in {model_path}") + raise FileNotFoundError(f"No safetensors found in {model_path}") + weights = {} + for wf in weight_files: + weights.update(mx.load(wf)) + + model_class, model_args_class = _get_classes(config=config) + + model_args = model_args_class.from_dict(config) + model = model_class(model_args) + + if quantization is not None: + nn.QuantizedLinear.quantize_module(model, **quantization) + + model.load_weights(list(weights.items())) + + mx.eval(model.parameters()) + tokenizer = AutoTokenizer.from_pretrained(model_path) + return model, tokenizer diff --git a/llms/phi2/README.md b/llms/phi2/README.md deleted file mode 100644 index 086cd17e..00000000 --- a/llms/phi2/README.md +++ /dev/null @@ -1,58 +0,0 @@ -# Phi-2 - -Phi-2 is a 2.7B parameter language model released by Microsoft with -performance that rivals much larger models.[^1] It was trained on a mixture of -GPT-4 outputs and clean web text. - -Phi-2 efficiently runs on Apple silicon devices with 8GB of memory in 16-bit -precision. - -### Setup - -Install the dependencies: - -``` -pip install -r requirements.txt -``` - -### Run -``` -python generate.py --model --prompt "hello" -``` -For example: - -``` -python generate.py --model microsoft/phi-2 --prompt "hello" -``` -The `` should be either a path to a local directory or a Hugging -Face repo with weights stored in `safetensors` format. If you use a repo from -the Hugging Face Hub, then the model will be downloaded and cached the first -time you run it. - -Run `python generate.py --help` to see all the options. - -### Convert new models - -You can convert (change the data type or quantize) models using the -`convert.py` script. This script takes a Hugging Face repo as input and outputs -a model directory (which you can optionally also upload to Hugging Face). - -For example, to make 4-bit quantized a model, run: - -``` -python convert.py --hf-path -q -``` - -For more options run: - -``` -python convert.py --help -``` - -You can upload new models to the [Hugging Face MLX -Community](https://huggingface.co/mlx-community) by specifying `--upload-name`` -to `convert.py`. 
- -[^1]: For more details on the model see the [blog post]( -https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/) -and the [Hugging Face repo](https://huggingface.co/microsoft/phi-2) \ No newline at end of file diff --git a/llms/phi2/convert.py b/llms/phi2/convert.py deleted file mode 100644 index 4cac6e82..00000000 --- a/llms/phi2/convert.py +++ /dev/null @@ -1,172 +0,0 @@ -import argparse -import copy -import glob -import json -from pathlib import Path - -import mlx.core as mx -import mlx.nn as nn -import transformers -from huggingface_hub import snapshot_download -from mlx.utils import tree_flatten -from phi2 import Model, ModelArgs - - -def fetch_from_hub(hf_path: str): - model_path = snapshot_download( - repo_id=hf_path, - allow_patterns=["*.json", "*.safetensors", "tokenizer.model"], - ) - weight_files = glob.glob(f"{model_path}/*.safetensors") - if len(weight_files) == 0: - raise FileNotFoundError("No safetensors found in {}".format(model_path)) - - weights = {} - for wf in weight_files: - weights.update(mx.load(wf).items()) - - config = transformers.AutoConfig.from_pretrained(hf_path, trust_remote_code=True) - tokenizer = transformers.AutoTokenizer.from_pretrained( - hf_path, - ) - return weights, config.to_dict(), tokenizer - - -def quantize(weights, config, args): - quantized_config = copy.deepcopy(config) - - # Load the model: - model = Model(ModelArgs.from_dict(config)) - model.load_weights(list(weights.items())) - - # Quantize the model: - nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits) - - # Update the config: - quantized_config["quantization"] = { - "group_size": args.q_group_size, - "bits": args.q_bits, - } - quantized_weights = dict(tree_flatten(model.parameters())) - - return quantized_weights, quantized_config - - -def make_shards(weights: dict, max_file_size_gibibyte: int = 15): - max_file_size_bytes = max_file_size_gibibyte << 30 - shards = [] - shard, shard_size = {}, 0 - for k, v in weights.items(): - estimated_size = v.size * v.dtype.size - if shard_size + estimated_size > max_file_size_bytes: - shards.append(shard) - shard, shard_size = {}, 0 - shard[k] = v - shard_size += estimated_size - shards.append(shard) - return shards - - -def upload_to_hub(path: str, name: str, hf_path: str): - import os - - from huggingface_hub import HfApi, ModelCard, logging - - repo_id = f"mlx-community/{name}" - - card = ModelCard.load(hf_path) - card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"] - card.text = f""" -# {name} -This model was converted to MLX format from [`{hf_path}`](). -Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model. 
-## Use with mlx -```bash -pip install mlx -git clone https://github.com/ml-explore/mlx-examples.git -cd mlx-examples/llms/hf_llm -python generate.py --model {repo_id} --prompt "My name is" -``` -""" - card.save(os.path.join(path, "README.md")) - - logging.set_verbosity_info() - - api = HfApi() - api.create_repo(repo_id=repo_id, exist_ok=True) - api.upload_folder( - folder_path=path, - repo_id=repo_id, - repo_type="model", - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Convert Hugging Face model to MLX format" - ) - parser.add_argument( - "--hf-path", - type=str, - help="Path to the Hugging Face model.", - ) - parser.add_argument( - "--mlx-path", - type=str, - default="mlx_model", - help="Path to save the MLX model.", - ) - parser.add_argument( - "-q", - "--quantize", - help="Generate a quantized model.", - action="store_true", - ) - parser.add_argument( - "--q-group-size", - help="Group size for quantization.", - type=int, - default=64, - ) - parser.add_argument( - "--q-bits", - help="Bits per weight for quantization.", - type=int, - default=4, - ) - parser.add_argument( - "--dtype", - help="Type to save the parameters, ignored if -q is given.", - type=str, - choices=["float16", "bfloat16", "float32"], - default="float16", - ) - parser.add_argument( - "--upload-name", - help="The name of model to upload to Hugging Face MLX Community", - type=str, - default=None, - ) - - args = parser.parse_args() - - print("[INFO] Loading") - weights, config, tokenizer = fetch_from_hub(args.hf_path) - - dtype = mx.float16 if args.quantize else getattr(mx, args.dtype) - weights = {k: v.astype(dtype) for k, v in weights.items()} - if args.quantize: - print("[INFO] Quantizing") - weights, config = quantize(weights, config, args) - - mlx_path = Path(args.mlx_path) - mlx_path.mkdir(parents=True, exist_ok=True) - shards = make_shards(weights) - for i, shard in enumerate(shards): - mx.save_safetensors(str(mlx_path / f"weights.{i:02d}.safetensors"), shard) - tokenizer.save_pretrained(mlx_path) - with open(mlx_path / "config.json", "w") as fid: - json.dump(config, fid, indent=4) - - if args.upload_name is not None: - upload_to_hub(mlx_path, args.upload_name, args.hf_path) diff --git a/llms/phi2/generate.py b/llms/phi2/generate.py deleted file mode 100644 index 2f176801..00000000 --- a/llms/phi2/generate.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright © 2023 Apple Inc. 
- -import argparse -import time - -import mlx.core as mx -import phi2 -import transformers - - -def generate( - model: phi2.Model, - tokenizer: transformers.AutoTokenizer, - prompt: str, - max_tokens: int, - temp: float = 0.0, -): - print("[INFO] Generating with Phi-2...", flush=True) - print(prompt, end="", flush=True) - prompt = tokenizer( - prompt, - return_tensors="np", - return_attention_mask=False, - )[ - "input_ids" - ][0] - prompt = mx.array(prompt) - - tic = time.time() - tokens = [] - skip = 0 - for token, n in zip( - phi2.generate(prompt, model, temp), - range(max_tokens), - ): - if token == tokenizer.eos_token_id: - break - - if n == 0: - prompt_time = time.time() - tic - tic = time.time() - - tokens.append(token.item()) - # if (n + 1) % 10 == 0: - s = tokenizer.decode(tokens) - print(s[skip:], end="", flush=True) - skip = len(s) - print(tokenizer.decode(tokens)[skip:], flush=True) - gen_time = time.time() - tic - print("=" * 10) - if len(tokens) == 0: - print("No tokens generated for this prompt") - return - prompt_tps = prompt.size / prompt_time - gen_tps = (len(tokens) - 1) / gen_time - print(f"Prompt: {prompt_tps:.3f} tokens-per-sec") - print(f"Generation: {gen_tps:.3f} tokens-per-sec") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="inference script") - parser.add_argument( - "--model", - type=str, - default="mlx_model", - help="The path to the local model directory or Hugging Face repo.", - ) - parser.add_argument( - "--prompt", - help="The message to be processed by the model", - default="Write a detailed analogy between mathematics and a lighthouse.", - ) - parser.add_argument( - "--max-tokens", - "-m", - type=int, - default=100, - help="Maximum number of tokens to generate", - ) - parser.add_argument( - "--temp", - help="The sampling temperature.", - type=float, - default=0.0, - ) - parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") - - args = parser.parse_args() - mx.random.seed(args.seed) - model, tokenizer = phi2.load(args.model) - generate(model, tokenizer, args.prompt, args.max_tokens, args.temp) diff --git a/llms/phi2/phi2.py b/llms/phi2/phi2.py deleted file mode 100644 index ae47ee26..00000000 --- a/llms/phi2/phi2.py +++ /dev/null @@ -1,224 +0,0 @@ -import glob -import inspect -import json -import math -from dataclasses import dataclass -from pathlib import Path -from typing import Optional - -import mlx.core as mx -import mlx.nn as nn -from huggingface_hub import snapshot_download -from mlx.utils import tree_unflatten -from transformers import AutoTokenizer - - -@dataclass -class ModelArgs: - max_sequence_length: int = 2048 - num_vocab: int = 51200 - model_dim: int = 2560 - num_heads: int = 32 - num_layers: int = 32 - rotary_dim: int = 32 - - @classmethod - def from_dict(cls, params): - return cls( - **{ - k: v - for k, v in params.items() - if k in inspect.signature(cls).parameters - } - ) - - -class LayerNorm(nn.LayerNorm): - def __call__(self, x: mx.array) -> mx.array: - return super().__call__(x.astype(mx.float32)).astype(x.dtype) - - -class RoPEAttention(nn.Module): - def __init__(self, dims: int, num_heads: int, rotary_dim: int): - super().__init__() - - self.num_heads = num_heads - - self.rope = nn.RoPE(rotary_dim, traditional=False) - self.Wqkv = nn.Linear(dims, 3 * dims) - self.out_proj = nn.Linear(dims, dims) - - def __call__(self, x, mask=None, cache=None): - qkv = self.Wqkv(x) - queries, keys, values = mx.split(qkv, 3, axis=-1) - - # Extract some shapes - num_heads = self.num_heads - B, L, D = 
queries.shape - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) - - # Add RoPE to the queries and keys and combine them with the cache - if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - queries = queries.astype(mx.float32) - keys = keys.astype(mx.float32) - - # Finally perform the attention computation - scale = math.sqrt(1 / queries.shape[-1]) - scores = (queries * scale) @ keys.transpose(0, 1, 3, 2) - if mask is not None: - scores = scores + mask - - scores = mx.softmax(scores, axis=-1).astype(values.dtype) - values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) - - return self.out_proj(values_hat), (keys, values) - - -class MLP(nn.Module): - def __init__(self, dim, hidden_dim): - super().__init__() - self.fc1 = nn.Linear(dim, hidden_dim) - self.fc2 = nn.Linear(hidden_dim, dim) - self.act = nn.GELU(approx="precise") - - def __call__(self, x) -> mx.array: - return self.fc2(self.act(self.fc1(x))) - - -class ParallelBlock(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - dims = config.model_dim - mlp_dims = dims * 4 - self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim) - self.ln = LayerNorm(dims) - self.mlp = MLP(dims, mlp_dims) - - def __call__(self, x, mask, cache): - h = self.ln(x) - attn_h, cache = self.mixer(h, mask, cache) - ff_h = self.mlp(h) - return attn_h + ff_h + x, cache - - -class TransformerDecoder(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.embd = Embd(config) - self.h = [ParallelBlock(config) for i in range(config.num_layers)] - - def __call__(self, x, mask, cache): - x = self.embd(x) - if cache is None: - cache = [None] * len(self.h) - - for e, layer in enumerate(self.h): - x, cache[e] = layer(x, mask, cache[e]) - return x, cache - - -class Embd(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.wte = nn.Embedding(config.num_vocab, config.model_dim) - - def __call__(self, x): - return self.wte(x) - - -class OutputHead(nn.Module): - def __init__(self, config: ModelArgs) -> None: - super().__init__() - self.ln = LayerNorm(config.model_dim) - self.linear = nn.Linear(config.model_dim, config.num_vocab) - - def __call__(self, inputs): - return self.linear(self.ln(inputs)) - - -class Model(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.transformer = TransformerDecoder(config) - self.lm_head = OutputHead(config) - - def __call__( - self, - x: mx.array, - mask: mx.array = None, - cache: mx.array = None, - ) -> tuple[mx.array, mx.array]: - mask = None - if x.shape[1] > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) - mask = mask.astype(x.dtype) - - y, cache = self.transformer(x, mask, cache) - return self.lm_head(y), cache - - -def generate(prompt: mx.array, model: Model, temp: float = 0.0): - def sample(logits): - if temp == 0: - return mx.argmax(logits, axis=-1) - else: - return mx.random.categorical(logits * (1 / temp)) - - y = prompt - cache = None - while True: - logits, cache = 
model(y[None], cache=cache) - logits = logits[:, -1, :] - y = sample(logits) - yield y - - -def load(path_or_hf_repo: str): - # If the path exists, it will try to load model form it - # otherwise download and cache from the hf_repo and cache - model_path = Path(path_or_hf_repo) - if not model_path.exists(): - model_path = Path( - snapshot_download( - repo_id=path_or_hf_repo, - allow_patterns=["*.json", "*.safetensors", "tokenizer.model"], - ) - ) - - with open(model_path / "config.json", "r") as f: - config = json.loads(f.read()) - quantization = config.get("quantization", None) - model_args = ModelArgs.from_dict(config) - - weight_files = glob.glob(str(model_path / "*.safetensors")) - if len(weight_files) == 0: - raise FileNotFoundError("No safetensors found in {}".format(model_path)) - - weights = {} - for wf in weight_files: - weights.update(mx.load(wf).items()) - - model = Model(model_args) - if quantization is not None: - nn.QuantizedLinear.quantize_module(model, **quantization) - - model.load_weights(list(weights.items())) - - mx.eval(model.parameters()) - tokenizer = AutoTokenizer.from_pretrained( - model_path, - ) - return model, tokenizer diff --git a/llms/phi2/requirements.txt b/llms/phi2/requirements.txt deleted file mode 100644 index 44ee5f6d..00000000 --- a/llms/phi2/requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ -einops -mlx -numpy -transformers>=4.35 -torch
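With the standalone phi2 example removed, architecture support now lives behind the `MODEL_MAPPING` table in `llms/hf_llm/utils.py`: each entry is keyed by the config's `model_type` and must point at a module exposing `Model` and `ModelArgs`. As a hedged illustration of that extension point (not part of this patch), the sketch below shows how another family could be wired in; the `qwen` module and its name are hypothetical.

```python
# Illustrative only: the model_type dispatch utils.py uses, with a
# hypothetical new entry. llama and phi2 are the modules this patch adds;
# models/qwen.py does not exist in this patch.
import models.llama as llama
import models.phi2 as phi2
# import models.qwen as qwen  # hypothetical: would expose Model and ModelArgs

MODEL_MAPPING = {
    "llama": llama,
    "mistral": llama,  # Mistral checkpoints reuse the Llama implementation
    "phi": phi2,
    # "qwen": qwen,    # keyed by config.json's "model_type"
}


def _get_classes(config: dict):
    # Same lookup utils.py performs before instantiating the model.
    arch = MODEL_MAPPING[config["model_type"]]
    return arch.Model, arch.ModelArgs
```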