refactor(hf_llm): moving phi2 example into hf_llm (#293)

* refactor: moving phi2 example into hf_llm

* chore: clean up

* chore: update phi2 model args so it can load args from config

* fix phi2 + nits + readme

* allow any HF repo, update README

* fix bug in llama

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Anchen
Date: 2024-01-11 12:29:12 -08:00 (committed by GitHub)
Parent: e74889d0fa
Commit: a2402116ae
15 changed files with 647 additions and 697 deletions

llms/hf_llm/.gitignore (new file)

@@ -0,0 +1 @@
mlx_model

llms/hf_llm/README.md

@@ -35,7 +35,7 @@ Run `python generate.py --help` to see all the options.
 ### Models
-The example supports Hugging Face format Mistral and Llama-style models. If the
+The example supports Hugging Face format Mistral, Llama, and Phi-2 style models. If the
 model you want to run is not supported, file an
 [issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
 submit a pull request.
@@ -47,11 +47,13 @@ Here are a few examples of Hugging Face models that work with this example:
 - [TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T](https://huggingface.co/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T)
 - [deepseek-ai/deepseek-coder-6.7b-instruct](https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct)
 - [01-ai/Yi-6B-Chat](https://huggingface.co/01-ai/Yi-6B-Chat)
+- [microsoft/phi-2](https://huggingface.co/microsoft/phi-2)
 Most
-[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending)
+[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
+[Llama](https://huggingface.co/models?library=transformers,safetensors&other=llama&sort=trending),
 and
-[Llama](https://huggingface.co/models?library=transformers,safetensors&other=llama&sort=trending)
+[Phi-2](https://huggingface.co/models?library=transformers,safetensors&other=phi&sort=trending)
 style models should work out of the box.
 ### Convert new models
@@ -72,6 +74,13 @@ For more options run:
 python convert.py --help
 ```
-You can upload new models to the [Hugging Face MLX
-Community](https://huggingface.co/mlx-community) by specifying `--upload-name`
-to `convert.py`.
+You can upload new models to Hugging Face by specifying `--upload-repo` to
+`convert.py`. For example, to upload a quantized Mistral-7B model to the
+[MLX Hugging Face community](https://huggingface.co/mlx-community) you can do:
+```
+python convert.py \
+    --hf-path mistralai/Mistral-7B-v0.1 \
+    -q \
+    --upload-repo mlx-community/my-4bit-mistral
+```
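A minimal usage sketch, for orientation only: once uploaded, the repo id can be passed anywhere a local model path is accepted, reusing the `load` and `generate` helpers this commit adds in `llms/hf_llm/utils.py`. The repo name is the hypothetical one from the README example above.

```python
# Sketch only: "mlx-community/my-4bit-mistral" is the hypothetical repo from the README example.
import mlx.core as mx

from utils import generate, load

model, tokenizer = load("mlx-community/my-4bit-mistral")  # downloaded and cached on first use
prompt = mx.array(tokenizer.encode("My name is"))
for token, _ in zip(generate(prompt, model, temp=0.6), range(32)):
    if token == tokenizer.eos_token_id:
        break
    print(tokenizer.decode([token.item()]), end="", flush=True)
```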

llms/hf_llm/convert.py

@@ -1,52 +1,95 @@
-# Copyright © 2023 Apple Inc.
 import argparse
 import copy
 import glob
 import json
 from pathlib import Path
+from typing import Dict, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 import transformers
-from huggingface_hub import snapshot_download
 from mlx.utils import tree_flatten
-from models import Model, ModelArgs
+from utils import get_model_path, load
+MAX_FILE_SIZE_GB = 15
-def fetch_from_hub(model_path: str, local: bool):
-    if not local:
-        model_path = snapshot_download(
-            repo_id=model_path,
-            allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
-        )
+def configure_parser() -> argparse.ArgumentParser:
+    """
+    Configures and returns the argument parser for the script.
+    Returns:
+        argparse.ArgumentParser: Configured argument parser.
+    """
+    parser = argparse.ArgumentParser(
+        description="Convert Hugging Face model to MLX format"
+    )
+    parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.")
+    parser.add_argument(
+        "--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model."
+    )
+    parser.add_argument(
+        "-q", "--quantize", help="Generate a quantized model.", action="store_true"
+    )
+    parser.add_argument(
+        "--q-group-size", help="Group size for quantization.", type=int, default=64
+    )
+    parser.add_argument(
+        "--q-bits", help="Bits per weight for quantization.", type=int, default=4
+    )
+    parser.add_argument(
+        "--dtype",
+        help="Type to save the parameters, ignored if -q is given.",
+        type=str,
+        choices=["float16", "bfloat16", "float32"],
+        default="float16",
+    )
+    parser.add_argument(
+        "--upload-repo",
+        help="The Hugging Face repo to upload the model to.",
+        type=str,
+        default=None,
+    )
+    return parser
+def fetch_from_hub(
+    model_path: str,
+) -> Tuple[Dict, dict, transformers.PreTrainedTokenizer]:
+    model_path = get_model_path(model_path)
     weight_files = glob.glob(f"{model_path}/*.safetensors")
-    if len(weight_files) == 0:
-        raise FileNotFoundError("No safetensors found in {}".format(model_path))
+    if not weight_files:
+        raise FileNotFoundError(f"No safetensors found in {model_path}")
     weights = {}
     for wf in weight_files:
         weights.update(mx.load(wf).items())
     config = transformers.AutoConfig.from_pretrained(model_path)
-    tokenizer = transformers.AutoTokenizer.from_pretrained(
-        model_path,
-    )
+    tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
     return weights, config.to_dict(), tokenizer
-def quantize(weights, config, args):
-    quantized_config = copy.deepcopy(config)
-    # Load the model:
-    model = Model(ModelArgs.from_dict(config))
+def quantize(weights: dict, config: dict, args: argparse.Namespace) -> tuple:
+    """
+    Applies quantization to the model weights.
+    Args:
+        weights (dict): Model weights.
+        config (dict): Model configuration.
+        args (argparse.Namespace): Command-line arguments.
+    Returns:
+        tuple: Tuple containing quantized weights and config.
+    """
+    quantized_config = copy.deepcopy(config)
+    model, _ = load(args.hf_path)
     model.load_weights(list(weights.items()))
-    # Quantize the model:
     nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
-    # Update the config:
     quantized_config["quantization"] = {
         "group_size": args.q_group_size,
         "bits": args.q_bits,
@@ -56,8 +99,18 @@ def quantize(weights, config, args):
     return quantized_weights, quantized_config
-def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
-    max_file_size_bytes = max_file_size_gibibyte << 30
+def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
+    """
+    Splits the weights into smaller shards.
+    Args:
+        weights (dict): Model weights.
+        max_file_size_gb (int): Maximum size of each shard in gigabytes.
+    Returns:
+        list: List of weight shards.
+    """
+    max_file_size_bytes = max_file_size_gb << 30
     shards = []
     shard, shard_size = {}, 0
     for k, v in weights.items():
@@ -71,17 +124,23 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
     return shards
-def upload_to_hub(path: str, name: str, hf_path: str):
+def upload_to_hub(path: str, upload_repo: str, hf_path: str):
+    """
+    Uploads the model to Hugging Face hub.
+    Args:
+        path (str): Local path to the model.
+        upload_repo (str): Name of the HF repo to upload to.
+        hf_path (str): Path to the original Hugging Face model.
+    """
     import os
     from huggingface_hub import HfApi, ModelCard, logging
-    repo_id = f"mlx-community/{name}"
     card = ModelCard.load(hf_path)
     card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
     card.text = f"""
-# {name}
+# {upload_repo}
 This model was converted to MLX format from [`{hf_path}`]().
 Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
 ## Use with mlx
@@ -97,72 +156,20 @@ python generate.py --model {repo_id} --prompt "My name is"
     logging.set_verbosity_info()
     api = HfApi()
-    api.create_repo(repo_id=repo_id, exist_ok=True)
+    api.create_repo(repo_id=upload_repo, exist_ok=True)
     api.upload_folder(
         folder_path=path,
-        repo_id=repo_id,
+        repo_id=upload_repo,
         repo_type="model",
     )
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Convert Hugging Face model to MLX format"
-    )
-    parser.add_argument(
-        "--hf-path",
-        type=str,
-        help="Path to the Hugging Face model.",
-    )
-    parser.add_argument(
-        "--mlx-path",
-        type=str,
-        default="mlx_model",
-        help="Path to save the MLX model.",
-    )
-    parser.add_argument(
-        "-q",
-        "--quantize",
-        help="Generate a quantized model.",
-        action="store_true",
-    )
-    parser.add_argument(
-        "--q-group-size",
-        help="Group size for quantization.",
-        type=int,
-        default=64,
-    )
-    parser.add_argument(
-        "--q-bits",
-        help="Bits per weight for quantization.",
-        type=int,
-        default=4,
-    )
-    parser.add_argument(
-        "--dtype",
-        help="Type to save the parameters, ignored if -q is given.",
-        type=str,
-        choices=["float16", "bfloat16", "float32"],
-        default="float16",
-    )
-    parser.add_argument(
-        "--upload-name",
-        help="The name of model to upload to Hugging Face MLX Community",
-        type=str,
-        default=None,
-    )
-    parser.add_argument(
-        "-l",
-        "--local",
-        action="store_true",
-        help="Whether the hf-path points to a local filesystem.",
-        default=False,
-    )
+    parser = configure_parser()
     args = parser.parse_args()
     print("[INFO] Loading")
-    weights, config, tokenizer = fetch_from_hub(args.hf_path, args.local)
+    weights, config, tokenizer = fetch_from_hub(args.hf_path)
     dtype = mx.float16 if args.quantize else getattr(mx, args.dtype)
     weights = {k: v.astype(dtype) for k, v in weights.items()}
@@ -179,5 +186,5 @@ if __name__ == "__main__":
     with open(mlx_path / "config.json", "w") as fid:
         json.dump(config, fid, indent=4)
-    if args.upload_name is not None and not args.local:
-        upload_to_hub(mlx_path, args.upload_name, args.hf_path)
+    if args.upload_repo is not None:
+        upload_to_hub(mlx_path, args.upload_repo, args.hf_path)

llms/hf_llm/generate.py

@@ -1,43 +1,58 @@
-# Copyright © 2023 Apple Inc.
 import argparse
 import time
 import mlx.core as mx
-import models
-import transformers
+from utils import generate, load
+DEFAULT_MODEL_PATH = "mlx_model"
+DEFAULT_PROMPT = "hello"
+DEFAULT_MAX_TOKENS = 100
+DEFAULT_TEMP = 0.6
+DEFAULT_SEED = 0
-def generate(
-    model: models.Model,
-    tokenizer: transformers.AutoTokenizer,
-    prompt: str,
-    max_tokens: int,
-    temp: float = 0.0,
-):
-    prompt = tokenizer(
-        prompt,
-        return_tensors="np",
-        return_attention_mask=False,
-    )[
-        "input_ids"
-    ][0]
+def setup_arg_parser():
+    """Set up and return the argument parser."""
+    parser = argparse.ArgumentParser(description="LLM inference script")
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="mlx_model",
+        help="The path to the local model directory or Hugging Face repo.",
+    )
+    parser.add_argument(
+        "--prompt", default=DEFAULT_PROMPT, help="Message to be processed by the model"
+    )
+    parser.add_argument(
+        "--max-tokens",
+        "-m",
+        type=int,
+        default=DEFAULT_MAX_TOKENS,
+        help="Maximum number of tokens to generate",
+    )
+    parser.add_argument(
+        "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
+    )
+    parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
+    return parser
+def main(args):
+    mx.random.seed(args.seed)
+    model, tokenizer = load(args.model)
+    print("=" * 10)
+    print("Prompt:", args.prompt)
+    prompt = tokenizer.encode(args.prompt)
     prompt = mx.array(prompt)
     tic = time.time()
     tokens = []
     skip = 0
-    for token, n in zip(
-        models.generate(prompt, model, temp),
-        range(max_tokens),
-    ):
+    for token, n in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
         if token == tokenizer.eos_token_id:
             break
         if n == 0:
             prompt_time = time.time() - tic
             tic = time.time()
         tokens.append(token.item())
         s = tokenizer.decode(tokens)
         print(s[skip:], end="", flush=True)
@@ -55,34 +70,6 @@ def generate(
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="inference script")
-    parser.add_argument(
-        "--model",
-        type=str,
-        default="mlx_model",
-        help="The path to the local model directory or Hugging Face repo.",
-    )
-    parser.add_argument(
-        "--prompt",
-        help="The message to be processed by the model",
-        default="In the beginning the Universe was created.",
-    )
-    parser.add_argument(
-        "--max-tokens",
-        "-m",
-        type=int,
-        default=100,
-        help="Maximum number of tokens to generate",
-    )
-    parser.add_argument(
-        "--temp",
-        help="The sampling temperature.",
-        type=float,
-        default=0.0,
-    )
-    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
+    parser = setup_arg_parser()
     args = parser.parse_args()
-    mx.random.seed(args.seed)
-    model, tokenizer = models.load(args.model)
-    generate(model, tokenizer, args.prompt, args.max_tokens, args.temp)
+    main(args)

llms/hf_llm/models/.gitignore (new, empty file)

llms/hf_llm/models/base.py (new file)

@@ -0,0 +1,15 @@
import inspect
from dataclasses import dataclass


@dataclass
class BaseModelArgs:
    @classmethod
    def from_dict(cls, params):
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )
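The point of `from_dict` is that unknown keys in a Hugging Face `config.json` are silently dropped, so the full config dict can be passed straight through to the per-architecture `ModelArgs` classes below. A small illustrative sketch (the `ToyArgs` dataclass is hypothetical, not a file in this commit):

```python
# Illustration only: ToyArgs is a made-up subclass, not part of this PR.
from dataclasses import dataclass

from models.base import BaseModelArgs


@dataclass
class ToyArgs(BaseModelArgs):
    hidden_size: int = 512
    num_layers: int = 4


config = {"hidden_size": 1024, "num_layers": 8, "torch_dtype": "float16", "model_type": "toy"}
print(ToyArgs.from_dict(config))  # ToyArgs(hidden_size=1024, num_layers=8) -- extra keys ignored
```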

llms/hf_llm/models/llama.py (new file)

@@ -0,0 +1,202 @@
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: int = None
rope_theta: float = 10000
rope_traditional: bool = False
model_type: str = None
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.rope_scaling:
required_keys = {"factor", "type"}
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["type"] != "linear":
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def _norm(self, x):
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
def __call__(self, x):
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
return self.weight * output
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.repeats = n_heads // n_kv_heads
head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
rope_scale = (
1 / args.rope_scaling["factor"]
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
else 1
)
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
def repeat(a):
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
return a.reshape([B, self.n_heads, L, -1])
if self.repeats > 1:
keys, values = map(repeat, (keys, values))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores += mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), (keys, values)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out, cache
class LlamaModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
return self.norm(h), cache
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model = LlamaModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out, cache = self.model(inputs, cache)
return self.lm_head(out), cache

llms/hf_llm/models/phi2.py (new file)

@@ -0,0 +1,138 @@
import math
from dataclasses import dataclass
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
n_positions: int = 2048
vocab_size: int = 51200
n_embd: int = 2560
n_head: int = 32
n_layer: int = 32
rotary_dim: int = 32
class LayerNorm(nn.LayerNorm):
def __call__(self, x: mx.array) -> mx.array:
return super().__call__(x.astype(mx.float32)).astype(x.dtype)
class RoPEAttention(nn.Module):
def __init__(self, dims: int, n_head: int, rotary_dim: int):
super().__init__()
self.n_head = n_head
self.q_proj = nn.Linear(dims, dims)
self.k_proj = nn.Linear(dims, dims)
self.v_proj = nn.Linear(dims, dims)
self.dense = nn.Linear(dims, dims)
self.rope = nn.RoPE(rotary_dim, traditional=False)
def __call__(self, x, mask=None, cache=None):
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Extract some shapes
n_head = self.n_head
B, L, D = queries.shape
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3)
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
queries = queries.astype(mx.float32)
keys = keys.astype(mx.float32)
# Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1).astype(values.dtype)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.dense(values_hat), (keys, values)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.fc1 = nn.Linear(dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, dim)
self.act = nn.GELU(approx="precise")
def __call__(self, x) -> mx.array:
return self.fc2(self.act(self.fc1(x)))
class ParallelBlock(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
dims = config.n_embd
mlp_dims = dims * 4
self.self_attn = RoPEAttention(dims, config.n_head, config.rotary_dim)
self.input_layernorm = LayerNorm(dims)
self.mlp = MLP(dims, mlp_dims)
def __call__(self, x, mask, cache):
h = self.input_layernorm(x)
attn_h, cache = self.self_attn(h, mask, cache)
ff_h = self.mlp(h)
return attn_h + ff_h + x, cache
class Transformer(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd)
self.layers = [ParallelBlock(config) for i in range(config.n_layer)]
self.final_layernorm = LayerNorm(config.n_embd)
def __call__(self, x, mask, cache):
x = self.embed_tokens(x)
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
x, cache[e] = layer(x, mask, cache[e])
return self.final_layernorm(x), cache
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.model = Transformer(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
def __call__(
self,
x: mx.array,
mask: mx.array = None,
cache: mx.array = None,
) -> tuple[mx.array, mx.array]:
mask = None
if x.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(x.dtype)
y, cache = self.model(x, mask, cache)
return self.lm_head(y), cache
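A quick interface sketch, illustrative only: the tiny, randomly initialized configuration below is made up and much smaller than the real Phi-2. It just shows that the model returns a `(logits, cache)` pair, with one `(keys, values)` entry per layer in the cache.

```python
# Illustration only: a deliberately tiny random config for a shape-level smoke test.
import mlx.core as mx

from models.phi2 import Model, ModelArgs

model = Model(ModelArgs(n_layer=2, n_embd=64, n_head=4, rotary_dim=16, vocab_size=100))
x = mx.array([[1, 2, 3]])      # (batch, sequence) token ids
logits, cache = model(x)       # logits: (1, 3, 100); cache: one (keys, values) pair per layer
print(logits.shape, len(cache))
```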

llms/hf_llm/utils.py (new file)

@@ -0,0 +1,141 @@
import glob
import json
import logging
from pathlib import Path
from typing import Generator, Tuple
import mlx.core as mx
import mlx.nn as nn
# Local imports
import models.llama as llama
import models.phi2 as phi2
from huggingface_hub import snapshot_download
from models.base import BaseModelArgs
from transformers import AutoTokenizer, PreTrainedTokenizer
# Constants
MODEL_MAPPING = {
"llama": llama,
"mistral": llama, # mistral is compatible with llama
"phi": phi2,
}
def _get_classes(config: dict):
"""
Retrieve the model and model args classes based on the configuration.
Args:
config (dict): The model configuration.
Returns:
A tuple containing the Model class and the ModelArgs class.
"""
model_type = config["model_type"]
if model_type not in MODEL_MAPPING:
msg = f"Model type {model_type} not supported."
logging.error(msg)
raise ValueError(msg)
arch = MODEL_MAPPING[model_type]
return arch.Model, arch.ModelArgs
def get_model_path(path_or_hf_repo: str) -> Path:
"""
Ensures the model is available locally. If the path does not exist locally,
it is downloaded from the Hugging Face Hub.
Args:
path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
Returns:
Path: The path to the model.
"""
model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
allow_patterns=["*.json", "*.safetensors", "*.py", "tokenizer.model"],
)
)
return model_path
def generate(
prompt: mx.array, model: nn.Module, temp: float = 0.0
) -> Generator[mx.array, None, None]:
"""
Generate text based on the given prompt and model.
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
temp (float): The temperature for sampling. If temp is 0, use max sampling.
Yields:
mx.array: The generated text.
"""
def sample(logits: mx.array) -> mx.array:
return (
mx.argmax(logits, axis=-1)
if temp == 0
else mx.random.categorical(logits * (1 / temp))
)
y = prompt
cache = None
while True:
logits, cache = model(y[None], cache=cache)
logits = logits[:, -1, :]
y = sample(logits)
yield y
def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
"""
Load the model from a given path or a huggingface repository.
Args:
path_or_hf_repo (str): The path or the huggingface repository to load the model from.
Returns:
Tuple[nn.Module, PreTrainedTokenizer]: The loaded model and tokenizer.
Raises:
FileNotFoundError: If config file or safetensors are not found.
ValueError: If model class or args class are not found.
"""
model_path = get_model_path(path_or_hf_repo)
try:
with open(model_path / "config.json", "r") as f:
config = json.load(f)
quantization = config.get("quantization", None)
except FileNotFoundError:
logging.error(f"Config file not found in {model_path}")
raise
weight_files = glob.glob(str(model_path / "*.safetensors"))
if not weight_files:
logging.error(f"No safetensors found in {model_path}")
raise FileNotFoundError(f"No safetensors found in {model_path}")
weights = {}
for wf in weight_files:
weights.update(mx.load(wf))
model_class, model_args_class = _get_classes(config=config)
model_args = model_args_class.from_dict(config)
model = model_class(model_args)
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
model.load_weights(list(weights.items()))
mx.eval(model.parameters())
tokenizer = AutoTokenizer.from_pretrained(model_path)
return model, tokenizer

llms/phi2/README.md (deleted file)

@@ -1,58 +0,0 @@
# Phi-2
Phi-2 is a 2.7B parameter language model released by Microsoft with
performance that rivals much larger models.[^1] It was trained on a mixture of
GPT-4 outputs and clean web text.
Phi-2 efficiently runs on Apple silicon devices with 8GB of memory in 16-bit
precision.
### Setup
Install the dependencies:
```
pip install -r requirements.txt
```
### Run
```
python generate.py --model <model_path> --prompt "hello"
```
For example:
```
python generate.py --model microsoft/phi-2 --prompt "hello"
```
The `<model_path>` should be either a path to a local directory or a Hugging
Face repo with weights stored in `safetensors` format. If you use a repo from
the Hugging Face Hub, then the model will be downloaded and cached the first
time you run it.
Run `python generate.py --help` to see all the options.
### Convert new models
You can convert (change the data type or quantize) models using the
`convert.py` script. This script takes a Hugging Face repo as input and outputs
a model directory (which you can optionally also upload to Hugging Face).
For example, to make 4-bit quantized a model, run:
```
python convert.py --hf-path <hf_repo> -q
```
For more options run:
```
python convert.py --help
```
You can upload new models to the [Hugging Face MLX
Community](https://huggingface.co/mlx-community) by specifying `--upload-name``
to `convert.py`.
[^1]: For more details on the model see the [blog post](
https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/)
and the [Hugging Face repo](https://huggingface.co/microsoft/phi-2)

llms/phi2/convert.py (deleted file)

@@ -1,172 +0,0 @@
import argparse
import copy
import glob
import json
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import transformers
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
from phi2 import Model, ModelArgs
def fetch_from_hub(hf_path: str):
model_path = snapshot_download(
repo_id=hf_path,
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
)
weight_files = glob.glob(f"{model_path}/*.safetensors")
if len(weight_files) == 0:
raise FileNotFoundError("No safetensors found in {}".format(model_path))
weights = {}
for wf in weight_files:
weights.update(mx.load(wf).items())
config = transformers.AutoConfig.from_pretrained(hf_path, trust_remote_code=True)
tokenizer = transformers.AutoTokenizer.from_pretrained(
hf_path,
)
return weights, config.to_dict(), tokenizer
def quantize(weights, config, args):
quantized_config = copy.deepcopy(config)
# Load the model:
model = Model(ModelArgs.from_dict(config))
model.load_weights(list(weights.items()))
# Quantize the model:
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
# Update the config:
quantized_config["quantization"] = {
"group_size": args.q_group_size,
"bits": args.q_bits,
}
quantized_weights = dict(tree_flatten(model.parameters()))
return quantized_weights, quantized_config
def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
max_file_size_bytes = max_file_size_gibibyte << 30
shards = []
shard, shard_size = {}, 0
for k, v in weights.items():
estimated_size = v.size * v.dtype.size
if shard_size + estimated_size > max_file_size_bytes:
shards.append(shard)
shard, shard_size = {}, 0
shard[k] = v
shard_size += estimated_size
shards.append(shard)
return shards
def upload_to_hub(path: str, name: str, hf_path: str):
import os
from huggingface_hub import HfApi, ModelCard, logging
repo_id = f"mlx-community/{name}"
card = ModelCard.load(hf_path)
card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
card.text = f"""
# {name}
This model was converted to MLX format from [`{hf_path}`]().
Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
## Use with mlx
```bash
pip install mlx
git clone https://github.com/ml-explore/mlx-examples.git
cd mlx-examples/llms/hf_llm
python generate.py --model {repo_id} --prompt "My name is"
```
"""
card.save(os.path.join(path, "README.md"))
logging.set_verbosity_info()
api = HfApi()
api.create_repo(repo_id=repo_id, exist_ok=True)
api.upload_folder(
folder_path=path,
repo_id=repo_id,
repo_type="model",
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Convert Hugging Face model to MLX format"
)
parser.add_argument(
"--hf-path",
type=str,
help="Path to the Hugging Face model.",
)
parser.add_argument(
"--mlx-path",
type=str,
default="mlx_model",
help="Path to save the MLX model.",
)
parser.add_argument(
"-q",
"--quantize",
help="Generate a quantized model.",
action="store_true",
)
parser.add_argument(
"--q-group-size",
help="Group size for quantization.",
type=int,
default=64,
)
parser.add_argument(
"--q-bits",
help="Bits per weight for quantization.",
type=int,
default=4,
)
parser.add_argument(
"--dtype",
help="Type to save the parameters, ignored if -q is given.",
type=str,
choices=["float16", "bfloat16", "float32"],
default="float16",
)
parser.add_argument(
"--upload-name",
help="The name of model to upload to Hugging Face MLX Community",
type=str,
default=None,
)
args = parser.parse_args()
print("[INFO] Loading")
weights, config, tokenizer = fetch_from_hub(args.hf_path)
dtype = mx.float16 if args.quantize else getattr(mx, args.dtype)
weights = {k: v.astype(dtype) for k, v in weights.items()}
if args.quantize:
print("[INFO] Quantizing")
weights, config = quantize(weights, config, args)
mlx_path = Path(args.mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
shards = make_shards(weights)
for i, shard in enumerate(shards):
mx.save_safetensors(str(mlx_path / f"weights.{i:02d}.safetensors"), shard)
tokenizer.save_pretrained(mlx_path)
with open(mlx_path / "config.json", "w") as fid:
json.dump(config, fid, indent=4)
if args.upload_name is not None:
upload_to_hub(mlx_path, args.upload_name, args.hf_path)

llms/phi2/generate.py (deleted file)

@@ -1,91 +0,0 @@
# Copyright © 2023 Apple Inc.
import argparse
import time
import mlx.core as mx
import phi2
import transformers
def generate(
model: phi2.Model,
tokenizer: transformers.AutoTokenizer,
prompt: str,
max_tokens: int,
temp: float = 0.0,
):
print("[INFO] Generating with Phi-2...", flush=True)
print(prompt, end="", flush=True)
prompt = tokenizer(
prompt,
return_tensors="np",
return_attention_mask=False,
)[
"input_ids"
][0]
prompt = mx.array(prompt)
tic = time.time()
tokens = []
skip = 0
for token, n in zip(
phi2.generate(prompt, model, temp),
range(max_tokens),
):
if token == tokenizer.eos_token_id:
break
if n == 0:
prompt_time = time.time() - tic
tic = time.time()
tokens.append(token.item())
# if (n + 1) % 10 == 0:
s = tokenizer.decode(tokens)
print(s[skip:], end="", flush=True)
skip = len(s)
print(tokenizer.decode(tokens)[skip:], flush=True)
gen_time = time.time() - tic
print("=" * 10)
if len(tokens) == 0:
print("No tokens generated for this prompt")
return
prompt_tps = prompt.size / prompt_time
gen_tps = (len(tokens) - 1) / gen_time
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="inference script")
parser.add_argument(
"--model",
type=str,
default="mlx_model",
help="The path to the local model directory or Hugging Face repo.",
)
parser.add_argument(
"--prompt",
help="The message to be processed by the model",
default="Write a detailed analogy between mathematics and a lighthouse.",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=100,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--temp",
help="The sampling temperature.",
type=float,
default=0.0,
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
args = parser.parse_args()
mx.random.seed(args.seed)
model, tokenizer = phi2.load(args.model)
generate(model, tokenizer, args.prompt, args.max_tokens, args.temp)

llms/phi2/phi2.py (deleted file)

@@ -1,224 +0,0 @@
import glob
import inspect
import json
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten
from transformers import AutoTokenizer
@dataclass
class ModelArgs:
max_sequence_length: int = 2048
num_vocab: int = 51200
model_dim: int = 2560
num_heads: int = 32
num_layers: int = 32
rotary_dim: int = 32
@classmethod
def from_dict(cls, params):
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
class LayerNorm(nn.LayerNorm):
def __call__(self, x: mx.array) -> mx.array:
return super().__call__(x.astype(mx.float32)).astype(x.dtype)
class RoPEAttention(nn.Module):
def __init__(self, dims: int, num_heads: int, rotary_dim: int):
super().__init__()
self.num_heads = num_heads
self.rope = nn.RoPE(rotary_dim, traditional=False)
self.Wqkv = nn.Linear(dims, 3 * dims)
self.out_proj = nn.Linear(dims, dims)
def __call__(self, x, mask=None, cache=None):
qkv = self.Wqkv(x)
queries, keys, values = mx.split(qkv, 3, axis=-1)
# Extract some shapes
num_heads = self.num_heads
B, L, D = queries.shape
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
queries = queries.astype(mx.float32)
keys = keys.astype(mx.float32)
# Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1).astype(values.dtype)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat), (keys, values)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.fc1 = nn.Linear(dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, dim)
self.act = nn.GELU(approx="precise")
def __call__(self, x) -> mx.array:
return self.fc2(self.act(self.fc1(x)))
class ParallelBlock(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
dims = config.model_dim
mlp_dims = dims * 4
self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
self.ln = LayerNorm(dims)
self.mlp = MLP(dims, mlp_dims)
def __call__(self, x, mask, cache):
h = self.ln(x)
attn_h, cache = self.mixer(h, mask, cache)
ff_h = self.mlp(h)
return attn_h + ff_h + x, cache
class TransformerDecoder(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.embd = Embd(config)
self.h = [ParallelBlock(config) for i in range(config.num_layers)]
def __call__(self, x, mask, cache):
x = self.embd(x)
if cache is None:
cache = [None] * len(self.h)
for e, layer in enumerate(self.h):
x, cache[e] = layer(x, mask, cache[e])
return x, cache
class Embd(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.wte = nn.Embedding(config.num_vocab, config.model_dim)
def __call__(self, x):
return self.wte(x)
class OutputHead(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.ln = LayerNorm(config.model_dim)
self.linear = nn.Linear(config.model_dim, config.num_vocab)
def __call__(self, inputs):
return self.linear(self.ln(inputs))
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.transformer = TransformerDecoder(config)
self.lm_head = OutputHead(config)
def __call__(
self,
x: mx.array,
mask: mx.array = None,
cache: mx.array = None,
) -> tuple[mx.array, mx.array]:
mask = None
if x.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(x.dtype)
y, cache = self.transformer(x, mask, cache)
return self.lm_head(y), cache
def generate(prompt: mx.array, model: Model, temp: float = 0.0):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
y = prompt
cache = None
while True:
logits, cache = model(y[None], cache=cache)
logits = logits[:, -1, :]
y = sample(logits)
yield y
def load(path_or_hf_repo: str):
# If the path exists, it will try to load model form it
# otherwise download and cache from the hf_repo and cache
model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
)
)
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
quantization = config.get("quantization", None)
model_args = ModelArgs.from_dict(config)
weight_files = glob.glob(str(model_path / "*.safetensors"))
if len(weight_files) == 0:
raise FileNotFoundError("No safetensors found in {}".format(model_path))
weights = {}
for wf in weight_files:
weights.update(mx.load(wf).items())
model = Model(model_args)
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
model.load_weights(list(weights.items()))
mx.eval(model.parameters())
tokenizer = AutoTokenizer.from_pretrained(
model_path,
)
return model, tokenizer

llms/phi2/requirements.txt (deleted file)

@@ -1,5 +0,0 @@
einops
mlx
numpy
transformers>=4.35
torch