mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
refactor(hf_llm): moving phi2 example into hf_llm (#293)
* refactor: moving phi2 example into hf_llm * chore: clean up * chore: update phi2 model args so it can load args from config * fix phi2 + nits + readme * allow any HF repo, update README * fix bug in llama --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
e74889d0fa
commit
a2402116ae
1
llms/hf_llm/.gitignore
vendored
Normal file
1
llms/hf_llm/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
mlx_model
|
@ -35,7 +35,7 @@ Run `python generate.py --help` to see all the options.
|
||||
|
||||
### Models
|
||||
|
||||
The example supports Hugging Face format Mistral and Llama-style models. If the
|
||||
The example supports Hugging Face format Mistral, Llama, and Phi-2 style models. If the
|
||||
model you want to run is not supported, file an
|
||||
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
|
||||
submit a pull request.
|
||||
@ -47,11 +47,13 @@ Here are a few examples of Hugging Face models that work with this example:
|
||||
- [TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T](https://huggingface.co/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T)
|
||||
- [deepseek-ai/deepseek-coder-6.7b-instruct](https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct)
|
||||
- [01-ai/Yi-6B-Chat](https://huggingface.co/01-ai/Yi-6B-Chat)
|
||||
- [microsoft/phi-2](https://huggingface.co/microsoft/phi-2)
|
||||
|
||||
Most
|
||||
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending)
|
||||
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
|
||||
[Llama](https://huggingface.co/models?library=transformers,safetensors&other=llama&sort=trending),
|
||||
and
|
||||
[Llama](https://huggingface.co/models?library=transformers,safetensors&other=llama&sort=trending)
|
||||
[Phi-2](https://huggingface.co/models?library=transformers,safetensors&other=phi&sort=trending)
|
||||
style models should work out of the box.
|
||||
|
||||
### Convert new models
|
||||
@ -72,6 +74,13 @@ For more options run:
|
||||
python convert.py --help
|
||||
```
|
||||
|
||||
You can upload new models to the [Hugging Face MLX
|
||||
Community](https://huggingface.co/mlx-community) by specifying `--upload-name`
|
||||
to `convert.py`.
|
||||
You can upload new models to Hugging Face by specifying `--upload-repo` to
|
||||
`convert.py`. For example, to upload a quantized Mistral-7B model to the
|
||||
[MLX Hugging Face community](https://huggingface.co/mlx-community) you can do:
|
||||
|
||||
```
|
||||
python convert.py \
|
||||
--hf-path mistralai/Mistral-7B-v0.1 \
|
||||
-q \
|
||||
--upload mlx-community/my-4bit-mistral \
|
||||
```
|
||||
|
@ -1,52 +1,95 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import glob
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import transformers
|
||||
from huggingface_hub import snapshot_download
|
||||
from mlx.utils import tree_flatten
|
||||
from models import Model, ModelArgs
|
||||
from utils import get_model_path, load
|
||||
|
||||
MAX_FILE_SIZE_GB = 15
|
||||
|
||||
|
||||
def fetch_from_hub(model_path: str, local: bool):
|
||||
if not local:
|
||||
model_path = snapshot_download(
|
||||
repo_id=model_path,
|
||||
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
|
||||
)
|
||||
def configure_parser() -> argparse.ArgumentParser:
|
||||
"""
|
||||
Configures and returns the argument parser for the script.
|
||||
|
||||
Returns:
|
||||
argparse.ArgumentParser: Configured argument parser.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert Hugging Face model to MLX format"
|
||||
)
|
||||
|
||||
parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.")
|
||||
parser.add_argument(
|
||||
"--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-q", "--quantize", help="Generate a quantized model.", action="store_true"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q-group-size", help="Group size for quantization.", type=int, default=64
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q-bits", help="Bits per weight for quantization.", type=int, default=4
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
help="Type to save the parameters, ignored if -q is given.",
|
||||
type=str,
|
||||
choices=["float16", "bfloat16", "float32"],
|
||||
default="float16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upload-repo",
|
||||
help="The Hugging Face repo to upload the model to.",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def fetch_from_hub(
|
||||
model_path: str,
|
||||
) -> Tuple[Dict, dict, transformers.PreTrainedTokenizer]:
|
||||
model_path = get_model_path(model_path)
|
||||
|
||||
weight_files = glob.glob(f"{model_path}/*.safetensors")
|
||||
if len(weight_files) == 0:
|
||||
raise FileNotFoundError("No safetensors found in {}".format(model_path))
|
||||
if not weight_files:
|
||||
raise FileNotFoundError(f"No safetensors found in {model_path}")
|
||||
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(wf).items())
|
||||
|
||||
config = transformers.AutoConfig.from_pretrained(model_path)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
||||
model_path,
|
||||
)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
|
||||
|
||||
return weights, config.to_dict(), tokenizer
|
||||
|
||||
|
||||
def quantize(weights, config, args):
|
||||
quantized_config = copy.deepcopy(config)
|
||||
def quantize(weights: dict, config: dict, args: argparse.Namespace) -> tuple:
|
||||
"""
|
||||
Applies quantization to the model weights.
|
||||
|
||||
# Load the model:
|
||||
model = Model(ModelArgs.from_dict(config))
|
||||
Args:
|
||||
weights (dict): Model weights.
|
||||
config (dict): Model configuration.
|
||||
args (argparse.Namespace): Command-line arguments.
|
||||
|
||||
Returns:
|
||||
tuple: Tuple containing quantized weights and config.
|
||||
"""
|
||||
quantized_config = copy.deepcopy(config)
|
||||
model, _ = load(args.hf_path)
|
||||
model.load_weights(list(weights.items()))
|
||||
|
||||
# Quantize the model:
|
||||
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
|
||||
|
||||
# Update the config:
|
||||
quantized_config["quantization"] = {
|
||||
"group_size": args.q_group_size,
|
||||
"bits": args.q_bits,
|
||||
@ -56,8 +99,18 @@ def quantize(weights, config, args):
|
||||
return quantized_weights, quantized_config
|
||||
|
||||
|
||||
def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
|
||||
max_file_size_bytes = max_file_size_gibibyte << 30
|
||||
def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
|
||||
"""
|
||||
Splits the weights into smaller shards.
|
||||
|
||||
Args:
|
||||
weights (dict): Model weights.
|
||||
max_file_size_gb (int): Maximum size of each shard in gigabytes.
|
||||
|
||||
Returns:
|
||||
list: List of weight shards.
|
||||
"""
|
||||
max_file_size_bytes = max_file_size_gb << 30
|
||||
shards = []
|
||||
shard, shard_size = {}, 0
|
||||
for k, v in weights.items():
|
||||
@ -71,17 +124,23 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
|
||||
return shards
|
||||
|
||||
|
||||
def upload_to_hub(path: str, name: str, hf_path: str):
|
||||
def upload_to_hub(path: str, upload_repo: str, hf_path: str):
|
||||
"""
|
||||
Uploads the model to Hugging Face hub.
|
||||
|
||||
Args:
|
||||
path (str): Local path to the model.
|
||||
upload_repo (str): Name of the HF repo to upload to.
|
||||
hf_path (str): Path to the original Hugging Face model.
|
||||
"""
|
||||
import os
|
||||
|
||||
from huggingface_hub import HfApi, ModelCard, logging
|
||||
|
||||
repo_id = f"mlx-community/{name}"
|
||||
|
||||
card = ModelCard.load(hf_path)
|
||||
card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
|
||||
card.text = f"""
|
||||
# {name}
|
||||
# {upload_repo}
|
||||
This model was converted to MLX format from [`{hf_path}`]().
|
||||
Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
|
||||
## Use with mlx
|
||||
@ -97,72 +156,20 @@ python generate.py --model {repo_id} --prompt "My name is"
|
||||
logging.set_verbosity_info()
|
||||
|
||||
api = HfApi()
|
||||
api.create_repo(repo_id=repo_id, exist_ok=True)
|
||||
api.create_repo(repo_id=upload_repo, exist_ok=True)
|
||||
api.upload_folder(
|
||||
folder_path=path,
|
||||
repo_id=repo_id,
|
||||
repo_id=upload_repo,
|
||||
repo_type="model",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert Hugging Face model to MLX format"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-path",
|
||||
type=str,
|
||||
help="Path to the Hugging Face model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mlx-path",
|
||||
type=str,
|
||||
default="mlx_model",
|
||||
help="Path to save the MLX model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-q",
|
||||
"--quantize",
|
||||
help="Generate a quantized model.",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q-group-size",
|
||||
help="Group size for quantization.",
|
||||
type=int,
|
||||
default=64,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q-bits",
|
||||
help="Bits per weight for quantization.",
|
||||
type=int,
|
||||
default=4,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
help="Type to save the parameters, ignored if -q is given.",
|
||||
type=str,
|
||||
choices=["float16", "bfloat16", "float32"],
|
||||
default="float16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upload-name",
|
||||
help="The name of model to upload to Hugging Face MLX Community",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--local",
|
||||
action="store_true",
|
||||
help="Whether the hf-path points to a local filesystem.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
parser = configure_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
print("[INFO] Loading")
|
||||
weights, config, tokenizer = fetch_from_hub(args.hf_path, args.local)
|
||||
weights, config, tokenizer = fetch_from_hub(args.hf_path)
|
||||
|
||||
dtype = mx.float16 if args.quantize else getattr(mx, args.dtype)
|
||||
weights = {k: v.astype(dtype) for k, v in weights.items()}
|
||||
@ -179,5 +186,5 @@ if __name__ == "__main__":
|
||||
with open(mlx_path / "config.json", "w") as fid:
|
||||
json.dump(config, fid, indent=4)
|
||||
|
||||
if args.upload_name is not None and not args.local:
|
||||
upload_to_hub(mlx_path, args.upload_name, args.hf_path)
|
||||
if args.upload_repo is not None:
|
||||
upload_to_hub(mlx_path, args.upload_repo, args.hf_path)
|
||||
|
@ -1,43 +1,58 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import models
|
||||
import transformers
|
||||
from utils import generate, load
|
||||
|
||||
DEFAULT_MODEL_PATH = "mlx_model"
|
||||
DEFAULT_PROMPT = "hello"
|
||||
DEFAULT_MAX_TOKENS = 100
|
||||
DEFAULT_TEMP = 0.6
|
||||
DEFAULT_SEED = 0
|
||||
|
||||
|
||||
def generate(
|
||||
model: models.Model,
|
||||
tokenizer: transformers.AutoTokenizer,
|
||||
prompt: str,
|
||||
max_tokens: int,
|
||||
temp: float = 0.0,
|
||||
):
|
||||
prompt = tokenizer(
|
||||
prompt,
|
||||
return_tensors="np",
|
||||
return_attention_mask=False,
|
||||
)[
|
||||
"input_ids"
|
||||
][0]
|
||||
def setup_arg_parser():
|
||||
"""Set up and return the argument parser."""
|
||||
parser = argparse.ArgumentParser(description="LLM inference script")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="mlx_model",
|
||||
help="The path to the local model directory or Hugging Face repo.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt", default=DEFAULT_PROMPT, help="Message to be processed by the model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
"-m",
|
||||
type=int,
|
||||
default=DEFAULT_MAX_TOKENS,
|
||||
help="Maximum number of tokens to generate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
|
||||
return parser
|
||||
|
||||
|
||||
def main(args):
|
||||
mx.random.seed(args.seed)
|
||||
model, tokenizer = load(args.model)
|
||||
print("=" * 10)
|
||||
print("Prompt:", args.prompt)
|
||||
prompt = tokenizer.encode(args.prompt)
|
||||
prompt = mx.array(prompt)
|
||||
|
||||
tic = time.time()
|
||||
tokens = []
|
||||
skip = 0
|
||||
for token, n in zip(
|
||||
models.generate(prompt, model, temp),
|
||||
range(max_tokens),
|
||||
):
|
||||
for token, n in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
|
||||
if token == tokenizer.eos_token_id:
|
||||
break
|
||||
|
||||
if n == 0:
|
||||
prompt_time = time.time() - tic
|
||||
tic = time.time()
|
||||
|
||||
tokens.append(token.item())
|
||||
s = tokenizer.decode(tokens)
|
||||
print(s[skip:], end="", flush=True)
|
||||
@ -55,34 +70,6 @@ def generate(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="inference script")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="mlx_model",
|
||||
help="The path to the local model directory or Hugging Face repo.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
help="The message to be processed by the model",
|
||||
default="In the beginning the Universe was created.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
"-m",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Maximum number of tokens to generate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temp",
|
||||
help="The sampling temperature.",
|
||||
type=float,
|
||||
default=0.0,
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||
|
||||
parser = setup_arg_parser()
|
||||
args = parser.parse_args()
|
||||
mx.random.seed(args.seed)
|
||||
model, tokenizer = models.load(args.model)
|
||||
generate(model, tokenizer, args.prompt, args.max_tokens, args.temp)
|
||||
main(args)
|
||||
|
0
llms/hf_llm/models/.gitignore
vendored
Normal file
0
llms/hf_llm/models/.gitignore
vendored
Normal file
0
llms/hf_llm/models/__init__.py
Normal file
0
llms/hf_llm/models/__init__.py
Normal file
15
llms/hf_llm/models/base.py
Normal file
15
llms/hf_llm/models/base.py
Normal file
@ -0,0 +1,15 @@
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelArgs:
|
||||
@classmethod
|
||||
def from_dict(cls, params):
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in params.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
202
llms/hf_llm/models/llama.py
Normal file
202
llms/hf_llm/models/llama.py
Normal file
@ -0,0 +1,202 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: int = None
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
model_type: str = None
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
if self.rope_scaling:
|
||||
required_keys = {"factor", "type"}
|
||||
if not all(key in self.rope_scaling for key in required_keys):
|
||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||
|
||||
if self.rope_scaling["type"] != "linear":
|
||||
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||
|
||||
def __call__(self, x):
|
||||
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
|
||||
return self.weight * output
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
self.repeats = n_heads // n_kv_heads
|
||||
|
||||
head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
|
||||
rope_scale = (
|
||||
1 / args.rope_scaling["factor"]
|
||||
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
|
||||
else 1
|
||||
)
|
||||
self.rope = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
scale=rope_scale,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
def repeat(a):
|
||||
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
|
||||
return a.reshape([B, self.n_heads, L, -1])
|
||||
|
||||
if self.repeats > 1:
|
||||
keys, values = map(repeat, (keys, values))
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores += mask
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output), (keys, values)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out, cache
|
||||
|
||||
|
||||
class LlamaModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for e, layer in enumerate(self.layers):
|
||||
h, cache[e] = layer(h, mask, cache[e])
|
||||
|
||||
return self.norm(h), cache
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model = LlamaModel(args)
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out, cache = self.model(inputs, cache)
|
||||
return self.lm_head(out), cache
|
138
llms/hf_llm/models/phi2.py
Normal file
138
llms/hf_llm/models/phi2.py
Normal file
@ -0,0 +1,138 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
n_positions: int = 2048
|
||||
vocab_size: int = 51200
|
||||
n_embd: int = 2560
|
||||
n_head: int = 32
|
||||
n_layer: int = 32
|
||||
rotary_dim: int = 32
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return super().__call__(x.astype(mx.float32)).astype(x.dtype)
|
||||
|
||||
|
||||
class RoPEAttention(nn.Module):
|
||||
def __init__(self, dims: int, n_head: int, rotary_dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.n_head = n_head
|
||||
|
||||
self.q_proj = nn.Linear(dims, dims)
|
||||
self.k_proj = nn.Linear(dims, dims)
|
||||
self.v_proj = nn.Linear(dims, dims)
|
||||
self.dense = nn.Linear(dims, dims)
|
||||
|
||||
self.rope = nn.RoPE(rotary_dim, traditional=False)
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Extract some shapes
|
||||
n_head = self.n_head
|
||||
B, L, D = queries.shape
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
# Add RoPE to the queries and keys and combine them with the cache
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
queries = queries.astype(mx.float32)
|
||||
keys = keys.astype(mx.float32)
|
||||
|
||||
# Finally perform the attention computation
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores = scores + mask
|
||||
|
||||
scores = mx.softmax(scores, axis=-1).astype(values.dtype)
|
||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
return self.dense(values_hat), (keys, values)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(dim, hidden_dim)
|
||||
self.fc2 = nn.Linear(hidden_dim, dim)
|
||||
self.act = nn.GELU(approx="precise")
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.fc2(self.act(self.fc1(x)))
|
||||
|
||||
|
||||
class ParallelBlock(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
dims = config.n_embd
|
||||
mlp_dims = dims * 4
|
||||
self.self_attn = RoPEAttention(dims, config.n_head, config.rotary_dim)
|
||||
self.input_layernorm = LayerNorm(dims)
|
||||
self.mlp = MLP(dims, mlp_dims)
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
h = self.input_layernorm(x)
|
||||
attn_h, cache = self.self_attn(h, mask, cache)
|
||||
ff_h = self.mlp(h)
|
||||
return attn_h + ff_h + x, cache
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd)
|
||||
self.layers = [ParallelBlock(config) for i in range(config.n_layer)]
|
||||
self.final_layernorm = LayerNorm(config.n_embd)
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
x = self.embed_tokens(x)
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for e, layer in enumerate(self.layers):
|
||||
x, cache[e] = layer(x, mask, cache[e])
|
||||
return self.final_layernorm(x), cache
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.model = Transformer(config)
|
||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache: mx.array = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
mask = None
|
||||
if x.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
mask = mask.astype(x.dtype)
|
||||
|
||||
y, cache = self.model(x, mask, cache)
|
||||
return self.lm_head(y), cache
|
141
llms/hf_llm/utils.py
Normal file
141
llms/hf_llm/utils.py
Normal file
@ -0,0 +1,141 @@
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Generator, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
# Local imports
|
||||
import models.llama as llama
|
||||
import models.phi2 as phi2
|
||||
from huggingface_hub import snapshot_download
|
||||
from models.base import BaseModelArgs
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
# Constants
|
||||
MODEL_MAPPING = {
|
||||
"llama": llama,
|
||||
"mistral": llama, # mistral is compatible with llama
|
||||
"phi": phi2,
|
||||
}
|
||||
|
||||
|
||||
def _get_classes(config: dict):
|
||||
"""
|
||||
Retrieve the model and model args classes based on the configuration.
|
||||
|
||||
Args:
|
||||
config (dict): The model configuration.
|
||||
|
||||
Returns:
|
||||
A tuple containing the Model class and the ModelArgs class.
|
||||
"""
|
||||
model_type = config["model_type"]
|
||||
if model_type not in MODEL_MAPPING:
|
||||
msg = f"Model type {model_type} not supported."
|
||||
logging.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
arch = MODEL_MAPPING[model_type]
|
||||
return arch.Model, arch.ModelArgs
|
||||
|
||||
|
||||
def get_model_path(path_or_hf_repo: str) -> Path:
|
||||
"""
|
||||
Ensures the model is available locally. If the path does not exist locally,
|
||||
it is downloaded from the Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
|
||||
|
||||
Returns:
|
||||
Path: The path to the model.
|
||||
"""
|
||||
model_path = Path(path_or_hf_repo)
|
||||
if not model_path.exists():
|
||||
model_path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_hf_repo,
|
||||
allow_patterns=["*.json", "*.safetensors", "*.py", "tokenizer.model"],
|
||||
)
|
||||
)
|
||||
return model_path
|
||||
|
||||
|
||||
def generate(
|
||||
prompt: mx.array, model: nn.Module, temp: float = 0.0
|
||||
) -> Generator[mx.array, None, None]:
|
||||
"""
|
||||
Generate text based on the given prompt and model.
|
||||
|
||||
Args:
|
||||
prompt (mx.array): The input prompt.
|
||||
model (nn.Module): The model to use for generation.
|
||||
temp (float): The temperature for sampling. If temp is 0, use max sampling.
|
||||
|
||||
Yields:
|
||||
mx.array: The generated text.
|
||||
"""
|
||||
|
||||
def sample(logits: mx.array) -> mx.array:
|
||||
return (
|
||||
mx.argmax(logits, axis=-1)
|
||||
if temp == 0
|
||||
else mx.random.categorical(logits * (1 / temp))
|
||||
)
|
||||
|
||||
y = prompt
|
||||
cache = None
|
||||
while True:
|
||||
logits, cache = model(y[None], cache=cache)
|
||||
logits = logits[:, -1, :]
|
||||
y = sample(logits)
|
||||
yield y
|
||||
|
||||
|
||||
def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
|
||||
"""
|
||||
Load the model from a given path or a huggingface repository.
|
||||
|
||||
Args:
|
||||
path_or_hf_repo (str): The path or the huggingface repository to load the model from.
|
||||
|
||||
Returns:
|
||||
Tuple[nn.Module, PreTrainedTokenizer]: The loaded model and tokenizer.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If config file or safetensors are not found.
|
||||
ValueError: If model class or args class are not found.
|
||||
"""
|
||||
model_path = get_model_path(path_or_hf_repo)
|
||||
|
||||
try:
|
||||
with open(model_path / "config.json", "r") as f:
|
||||
config = json.load(f)
|
||||
quantization = config.get("quantization", None)
|
||||
except FileNotFoundError:
|
||||
logging.error(f"Config file not found in {model_path}")
|
||||
raise
|
||||
weight_files = glob.glob(str(model_path / "*.safetensors"))
|
||||
if not weight_files:
|
||||
logging.error(f"No safetensors found in {model_path}")
|
||||
raise FileNotFoundError(f"No safetensors found in {model_path}")
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(wf))
|
||||
|
||||
model_class, model_args_class = _get_classes(config=config)
|
||||
|
||||
model_args = model_args_class.from_dict(config)
|
||||
model = model_class(model_args)
|
||||
|
||||
if quantization is not None:
|
||||
nn.QuantizedLinear.quantize_module(model, **quantization)
|
||||
|
||||
model.load_weights(list(weights.items()))
|
||||
|
||||
mx.eval(model.parameters())
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
return model, tokenizer
|
@ -1,58 +0,0 @@
|
||||
# Phi-2
|
||||
|
||||
Phi-2 is a 2.7B parameter language model released by Microsoft with
|
||||
performance that rivals much larger models.[^1] It was trained on a mixture of
|
||||
GPT-4 outputs and clean web text.
|
||||
|
||||
Phi-2 efficiently runs on Apple silicon devices with 8GB of memory in 16-bit
|
||||
precision.
|
||||
|
||||
### Setup
|
||||
|
||||
Install the dependencies:
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### Run
|
||||
```
|
||||
python generate.py --model <model_path> --prompt "hello"
|
||||
```
|
||||
For example:
|
||||
|
||||
```
|
||||
python generate.py --model microsoft/phi-2 --prompt "hello"
|
||||
```
|
||||
The `<model_path>` should be either a path to a local directory or a Hugging
|
||||
Face repo with weights stored in `safetensors` format. If you use a repo from
|
||||
the Hugging Face Hub, then the model will be downloaded and cached the first
|
||||
time you run it.
|
||||
|
||||
Run `python generate.py --help` to see all the options.
|
||||
|
||||
### Convert new models
|
||||
|
||||
You can convert (change the data type or quantize) models using the
|
||||
`convert.py` script. This script takes a Hugging Face repo as input and outputs
|
||||
a model directory (which you can optionally also upload to Hugging Face).
|
||||
|
||||
For example, to make 4-bit quantized a model, run:
|
||||
|
||||
```
|
||||
python convert.py --hf-path <hf_repo> -q
|
||||
```
|
||||
|
||||
For more options run:
|
||||
|
||||
```
|
||||
python convert.py --help
|
||||
```
|
||||
|
||||
You can upload new models to the [Hugging Face MLX
|
||||
Community](https://huggingface.co/mlx-community) by specifying `--upload-name``
|
||||
to `convert.py`.
|
||||
|
||||
[^1]: For more details on the model see the [blog post](
|
||||
https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/)
|
||||
and the [Hugging Face repo](https://huggingface.co/microsoft/phi-2)
|
@ -1,172 +0,0 @@
|
||||
import argparse
|
||||
import copy
|
||||
import glob
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import transformers
|
||||
from huggingface_hub import snapshot_download
|
||||
from mlx.utils import tree_flatten
|
||||
from phi2 import Model, ModelArgs
|
||||
|
||||
|
||||
def fetch_from_hub(hf_path: str):
|
||||
model_path = snapshot_download(
|
||||
repo_id=hf_path,
|
||||
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
|
||||
)
|
||||
weight_files = glob.glob(f"{model_path}/*.safetensors")
|
||||
if len(weight_files) == 0:
|
||||
raise FileNotFoundError("No safetensors found in {}".format(model_path))
|
||||
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(wf).items())
|
||||
|
||||
config = transformers.AutoConfig.from_pretrained(hf_path, trust_remote_code=True)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
||||
hf_path,
|
||||
)
|
||||
return weights, config.to_dict(), tokenizer
|
||||
|
||||
|
||||
def quantize(weights, config, args):
|
||||
quantized_config = copy.deepcopy(config)
|
||||
|
||||
# Load the model:
|
||||
model = Model(ModelArgs.from_dict(config))
|
||||
model.load_weights(list(weights.items()))
|
||||
|
||||
# Quantize the model:
|
||||
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
|
||||
|
||||
# Update the config:
|
||||
quantized_config["quantization"] = {
|
||||
"group_size": args.q_group_size,
|
||||
"bits": args.q_bits,
|
||||
}
|
||||
quantized_weights = dict(tree_flatten(model.parameters()))
|
||||
|
||||
return quantized_weights, quantized_config
|
||||
|
||||
|
||||
def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
|
||||
max_file_size_bytes = max_file_size_gibibyte << 30
|
||||
shards = []
|
||||
shard, shard_size = {}, 0
|
||||
for k, v in weights.items():
|
||||
estimated_size = v.size * v.dtype.size
|
||||
if shard_size + estimated_size > max_file_size_bytes:
|
||||
shards.append(shard)
|
||||
shard, shard_size = {}, 0
|
||||
shard[k] = v
|
||||
shard_size += estimated_size
|
||||
shards.append(shard)
|
||||
return shards
|
||||
|
||||
|
||||
def upload_to_hub(path: str, name: str, hf_path: str):
|
||||
import os
|
||||
|
||||
from huggingface_hub import HfApi, ModelCard, logging
|
||||
|
||||
repo_id = f"mlx-community/{name}"
|
||||
|
||||
card = ModelCard.load(hf_path)
|
||||
card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
|
||||
card.text = f"""
|
||||
# {name}
|
||||
This model was converted to MLX format from [`{hf_path}`]().
|
||||
Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
|
||||
## Use with mlx
|
||||
```bash
|
||||
pip install mlx
|
||||
git clone https://github.com/ml-explore/mlx-examples.git
|
||||
cd mlx-examples/llms/hf_llm
|
||||
python generate.py --model {repo_id} --prompt "My name is"
|
||||
```
|
||||
"""
|
||||
card.save(os.path.join(path, "README.md"))
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
api = HfApi()
|
||||
api.create_repo(repo_id=repo_id, exist_ok=True)
|
||||
api.upload_folder(
|
||||
folder_path=path,
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert Hugging Face model to MLX format"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-path",
|
||||
type=str,
|
||||
help="Path to the Hugging Face model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mlx-path",
|
||||
type=str,
|
||||
default="mlx_model",
|
||||
help="Path to save the MLX model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-q",
|
||||
"--quantize",
|
||||
help="Generate a quantized model.",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q-group-size",
|
||||
help="Group size for quantization.",
|
||||
type=int,
|
||||
default=64,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q-bits",
|
||||
help="Bits per weight for quantization.",
|
||||
type=int,
|
||||
default=4,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
help="Type to save the parameters, ignored if -q is given.",
|
||||
type=str,
|
||||
choices=["float16", "bfloat16", "float32"],
|
||||
default="float16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upload-name",
|
||||
help="The name of model to upload to Hugging Face MLX Community",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("[INFO] Loading")
|
||||
weights, config, tokenizer = fetch_from_hub(args.hf_path)
|
||||
|
||||
dtype = mx.float16 if args.quantize else getattr(mx, args.dtype)
|
||||
weights = {k: v.astype(dtype) for k, v in weights.items()}
|
||||
if args.quantize:
|
||||
print("[INFO] Quantizing")
|
||||
weights, config = quantize(weights, config, args)
|
||||
|
||||
mlx_path = Path(args.mlx_path)
|
||||
mlx_path.mkdir(parents=True, exist_ok=True)
|
||||
shards = make_shards(weights)
|
||||
for i, shard in enumerate(shards):
|
||||
mx.save_safetensors(str(mlx_path / f"weights.{i:02d}.safetensors"), shard)
|
||||
tokenizer.save_pretrained(mlx_path)
|
||||
with open(mlx_path / "config.json", "w") as fid:
|
||||
json.dump(config, fid, indent=4)
|
||||
|
||||
if args.upload_name is not None:
|
||||
upload_to_hub(mlx_path, args.upload_name, args.hf_path)
|
@ -1,91 +0,0 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import phi2
|
||||
import transformers
|
||||
|
||||
|
||||
def generate(
|
||||
model: phi2.Model,
|
||||
tokenizer: transformers.AutoTokenizer,
|
||||
prompt: str,
|
||||
max_tokens: int,
|
||||
temp: float = 0.0,
|
||||
):
|
||||
print("[INFO] Generating with Phi-2...", flush=True)
|
||||
print(prompt, end="", flush=True)
|
||||
prompt = tokenizer(
|
||||
prompt,
|
||||
return_tensors="np",
|
||||
return_attention_mask=False,
|
||||
)[
|
||||
"input_ids"
|
||||
][0]
|
||||
prompt = mx.array(prompt)
|
||||
|
||||
tic = time.time()
|
||||
tokens = []
|
||||
skip = 0
|
||||
for token, n in zip(
|
||||
phi2.generate(prompt, model, temp),
|
||||
range(max_tokens),
|
||||
):
|
||||
if token == tokenizer.eos_token_id:
|
||||
break
|
||||
|
||||
if n == 0:
|
||||
prompt_time = time.time() - tic
|
||||
tic = time.time()
|
||||
|
||||
tokens.append(token.item())
|
||||
# if (n + 1) % 10 == 0:
|
||||
s = tokenizer.decode(tokens)
|
||||
print(s[skip:], end="", flush=True)
|
||||
skip = len(s)
|
||||
print(tokenizer.decode(tokens)[skip:], flush=True)
|
||||
gen_time = time.time() - tic
|
||||
print("=" * 10)
|
||||
if len(tokens) == 0:
|
||||
print("No tokens generated for this prompt")
|
||||
return
|
||||
prompt_tps = prompt.size / prompt_time
|
||||
gen_tps = (len(tokens) - 1) / gen_time
|
||||
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
||||
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="inference script")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="mlx_model",
|
||||
help="The path to the local model directory or Hugging Face repo.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
help="The message to be processed by the model",
|
||||
default="Write a detailed analogy between mathematics and a lighthouse.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
"-m",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Maximum number of tokens to generate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temp",
|
||||
help="The sampling temperature.",
|
||||
type=float,
|
||||
default=0.0,
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||
|
||||
args = parser.parse_args()
|
||||
mx.random.seed(args.seed)
|
||||
model, tokenizer = phi2.load(args.model)
|
||||
generate(model, tokenizer, args.prompt, args.max_tokens, args.temp)
|
@ -1,224 +0,0 @@
|
||||
import glob
|
||||
import inspect
|
||||
import json
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from mlx.utils import tree_unflatten
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
max_sequence_length: int = 2048
|
||||
num_vocab: int = 51200
|
||||
model_dim: int = 2560
|
||||
num_heads: int = 32
|
||||
num_layers: int = 32
|
||||
rotary_dim: int = 32
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, params):
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in params.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return super().__call__(x.astype(mx.float32)).astype(x.dtype)
|
||||
|
||||
|
||||
class RoPEAttention(nn.Module):
|
||||
def __init__(self, dims: int, num_heads: int, rotary_dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.rope = nn.RoPE(rotary_dim, traditional=False)
|
||||
self.Wqkv = nn.Linear(dims, 3 * dims)
|
||||
self.out_proj = nn.Linear(dims, dims)
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
qkv = self.Wqkv(x)
|
||||
queries, keys, values = mx.split(qkv, 3, axis=-1)
|
||||
|
||||
# Extract some shapes
|
||||
num_heads = self.num_heads
|
||||
B, L, D = queries.shape
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
# Add RoPE to the queries and keys and combine them with the cache
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
queries = queries.astype(mx.float32)
|
||||
keys = keys.astype(mx.float32)
|
||||
|
||||
# Finally perform the attention computation
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores = scores + mask
|
||||
|
||||
scores = mx.softmax(scores, axis=-1).astype(values.dtype)
|
||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
return self.out_proj(values_hat), (keys, values)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(dim, hidden_dim)
|
||||
self.fc2 = nn.Linear(hidden_dim, dim)
|
||||
self.act = nn.GELU(approx="precise")
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.fc2(self.act(self.fc1(x)))
|
||||
|
||||
|
||||
class ParallelBlock(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
dims = config.model_dim
|
||||
mlp_dims = dims * 4
|
||||
self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
|
||||
self.ln = LayerNorm(dims)
|
||||
self.mlp = MLP(dims, mlp_dims)
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
h = self.ln(x)
|
||||
attn_h, cache = self.mixer(h, mask, cache)
|
||||
ff_h = self.mlp(h)
|
||||
return attn_h + ff_h + x, cache
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.embd = Embd(config)
|
||||
self.h = [ParallelBlock(config) for i in range(config.num_layers)]
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
x = self.embd(x)
|
||||
if cache is None:
|
||||
cache = [None] * len(self.h)
|
||||
|
||||
for e, layer in enumerate(self.h):
|
||||
x, cache[e] = layer(x, mask, cache[e])
|
||||
return x, cache
|
||||
|
||||
|
||||
class Embd(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.wte = nn.Embedding(config.num_vocab, config.model_dim)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.wte(x)
|
||||
|
||||
|
||||
class OutputHead(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.ln = LayerNorm(config.model_dim)
|
||||
self.linear = nn.Linear(config.model_dim, config.num_vocab)
|
||||
|
||||
def __call__(self, inputs):
|
||||
return self.linear(self.ln(inputs))
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.transformer = TransformerDecoder(config)
|
||||
self.lm_head = OutputHead(config)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache: mx.array = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
mask = None
|
||||
if x.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
mask = mask.astype(x.dtype)
|
||||
|
||||
y, cache = self.transformer(x, mask, cache)
|
||||
return self.lm_head(y), cache
|
||||
|
||||
|
||||
def generate(prompt: mx.array, model: Model, temp: float = 0.0):
|
||||
def sample(logits):
|
||||
if temp == 0:
|
||||
return mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
return mx.random.categorical(logits * (1 / temp))
|
||||
|
||||
y = prompt
|
||||
cache = None
|
||||
while True:
|
||||
logits, cache = model(y[None], cache=cache)
|
||||
logits = logits[:, -1, :]
|
||||
y = sample(logits)
|
||||
yield y
|
||||
|
||||
|
||||
def load(path_or_hf_repo: str):
|
||||
# If the path exists, it will try to load model form it
|
||||
# otherwise download and cache from the hf_repo and cache
|
||||
model_path = Path(path_or_hf_repo)
|
||||
if not model_path.exists():
|
||||
model_path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_hf_repo,
|
||||
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
|
||||
)
|
||||
)
|
||||
|
||||
with open(model_path / "config.json", "r") as f:
|
||||
config = json.loads(f.read())
|
||||
quantization = config.get("quantization", None)
|
||||
model_args = ModelArgs.from_dict(config)
|
||||
|
||||
weight_files = glob.glob(str(model_path / "*.safetensors"))
|
||||
if len(weight_files) == 0:
|
||||
raise FileNotFoundError("No safetensors found in {}".format(model_path))
|
||||
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(wf).items())
|
||||
|
||||
model = Model(model_args)
|
||||
if quantization is not None:
|
||||
nn.QuantizedLinear.quantize_module(model, **quantization)
|
||||
|
||||
model.load_weights(list(weights.items()))
|
||||
|
||||
mx.eval(model.parameters())
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_path,
|
||||
)
|
||||
return model, tokenizer
|
@ -1,5 +0,0 @@
|
||||
einops
|
||||
mlx
|
||||
numpy
|
||||
transformers>=4.35
|
||||
torch
|
Loading…
Reference in New Issue
Block a user