Mlx llm package (#301)

* fix converter

* add recursive files

* remove gitignore

* remove gitignore

* add packages properly

* read me update

* remove dup readme

* relative

* fix convert

* fix community name

* fix url

* version
Awni Hannun 2024-01-12 10:25:56 -08:00 committed by GitHub
parent 2b61d9deb6
commit c6440416a2
17 changed files with 270 additions and 388 deletions

llms/MANIFEST.in Normal file (+2)

@@ -0,0 +1,2 @@
include mlx_lm/requirements.txt
recursive-include mlx_lm/ *.py

llms/README.md Normal file (+110)

@@ -0,0 +1,110 @@
## Generate Text with LLMs and MLX
The easiest way to get started is to install the `mlx-lm` package:
```shell
pip install mlx-lm
```
### Python API
You can use `mlx-lm` as a module:
```python
from mlx_lm import load, generate
model, tokenizer = load("mistralai/Mistral-7B-v0.1")
response = generate(model, tokenizer, prompt="hello", verbose=True)
```
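`generate` also accepts sampling options; the keyword arguments below (`temp`, `max_tokens`, `verbose`) mirror the signature added in this package, and the prompt is just an illustrative placeholder:
```python
from mlx_lm import load, generate

model, tokenizer = load("mistralai/Mistral-7B-v0.1")

# Sample with a non-zero temperature and allow a longer completion.
response = generate(
    model,
    tokenizer,
    prompt="Write a haiku about Apple silicon",
    temp=0.7,        # 0.0 (the default) means greedy/argmax sampling
    max_tokens=200,  # default is 100
    verbose=True,    # stream tokens to stdout as they are generated
)
```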
To see a description of all the arguments you can do:
```
>>> help(generate)
```
The `mlx-lm` package also comes with functionality to quantize and optionally
upload models to the Hugging Face Hub.
You can convert models in the Python API with:
```python
from mlx_lm import convert
upload_repo = "mlx-community/My-Mistral-7B-v0.1-4bit"
convert("mistralai/Mistral-7B-v0.1", quantize=True, upload_repo=upload_repo)
```
This will generate a 4-bit quantized Mistral-7B and upload it to the
repo `mlx-community/My-Mistral-7B-v0.1-4bit`. It will also save the
converted model in the path `mlx_model` by default.
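The other `convert` parameters (output path, data type, quantization settings) can be overridden as well; a minimal sketch, with a hypothetical output directory name:
```python
from mlx_lm import convert

# Convert to float16 without quantizing and write to a custom directory.
convert(
    "mistralai/Mistral-7B-v0.1",
    mlx_path="mistral-7b-mlx-fp16",  # hypothetical path; defaults to "mlx_model"
    quantize=False,
    dtype="float16",
)
```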
To see a description of all the arguments you can do:
```
>>> help(convert)
```
### Command Line
You can also use `mlx-lm` from the command line with:
```
python -m mlx_lm.generate --model mistralai/Mistral-7B-v0.1 --prompt "hello"
```
This will download a Mistral 7B model from the Hugging Face Hub and generate
text using the given prompt.
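Any of the supported model families listed below can be passed to `--model` in the same way, for example:
```
python -m mlx_lm.generate --model microsoft/phi-2 --prompt "Write a haiku about the ocean."
```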
For a full list of options run:
```
python -m mlx_lm.generate --help
```
To quantize a model from the command line run:
```
python -m mlx_lm.convert --hf-path mistralai/Mistral-7B-v0.1 -q
```
For more options run:
```
python -m mlx_lm.convert --help
```
You can upload new models to Hugging Face by specifying `--upload-repo` to
`convert`. For example, to upload a quantized Mistral-7B model to the
[MLX Hugging Face community](https://huggingface.co/mlx-community) you can do:
```
python -m mlx_lm.convert \
--hf-path mistralai/Mistral-7B-v0.1 \
-q \
--upload-repo mlx-community/my-4bit-mistral
```
### Supported Models
The example supports Hugging Face format Mistral, Llama, and Phi-2 style
models. If the model you want to run is not supported, file an
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
submit a pull request.
Here are a few examples of Hugging Face models that work with this example:
- [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
- [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf)
- [deepseek-ai/deepseek-coder-6.7b-instruct](https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct)
- [01-ai/Yi-6B-Chat](https://huggingface.co/01-ai/Yi-6B-Chat)
- [microsoft/phi-2](https://huggingface.co/microsoft/phi-2)
Most
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
[Llama](https://huggingface.co/models?library=transformers,safetensors&other=llama&sort=trending),
and
[Phi-2](https://huggingface.co/models?library=transformers,safetensors&other=phi&sort=trending)
style models should work out of the box.

View File

@@ -1 +0,0 @@
mlx_model

View File

@@ -1,86 +0,0 @@
## Generate Text with MLX and :hugs: Hugging Face
This is an example of large language model text generation that can pull models from
the Hugging Face Hub.
### Setup
Install the dependencies:
```
pip install -r requirements.txt
```
### Run
```
python generate.py --model <model_path> --prompt "hello"
```
For example:
```
python generate.py --model mistralai/Mistral-7B-v0.1 --prompt "hello"
```
will download the Mistral 7B model and generate text using the given prompt.
The `<model_path>` should be either a path to a local directory or a Hugging
Face repo with weights stored in `safetensors` format. If you use a repo from
the Hugging Face Hub, then the model will be downloaded and cached the first
time you run it. See the [Models](#models) section for a full list of supported models.
Run `python generate.py --help` to see all the options.
### Models
The example supports Hugging Face format Mistral, Llama, and Phi-2 style models. If the
model you want to run is not supported, file an
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
submit a pull request.
Here are a few examples of Hugging Face models that work with this example:
- [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
- [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf)
- [TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T](https://huggingface.co/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T)
- [deepseek-ai/deepseek-coder-6.7b-instruct](https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct)
- [01-ai/Yi-6B-Chat](https://huggingface.co/01-ai/Yi-6B-Chat)
- [microsoft/phi-2](https://huggingface.co/microsoft/phi-2)
Most
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
[Llama](https://huggingface.co/models?library=transformers,safetensors&other=llama&sort=trending),
and
[Phi-2](https://huggingface.co/models?library=transformers,safetensors&other=phi&sort=trending)
style models should work out of the box.
### Convert new models
You can convert (change the data type or quantize) models using the
`convert.py` script. This script takes a Hugging Face repo as input and outputs
a model directory (which you can optionally also upload to Hugging Face).
For example, to make a 4-bit quantized model, run:
```
python convert.py --hf-path <hf_repo> -q
```
For more options run:
```
python convert.py --help
```
You can upload new models to Hugging Face by specifying `--upload-repo` to
`convert.py`. For example, to upload a quantized Mistral-7B model to the
[MLX Hugging Face community](https://huggingface.co/mlx-community) you can do:
```
python convert.py \
--hf-path mistralai/Mistral-7B-v0.1 \
-q \
--upload-repo mlx-community/my-4bit-mistral
```

View File

@@ -1,269 +0,0 @@
# Copyright © 2023 Apple Inc.
import glob
import inspect
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
@dataclass
class ModelArgs:
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: int = None
rope_theta: float = 10000
rope_traditional: bool = False
model_type: str = None
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.rope_scaling:
required_keys = {"factor", "type"}
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["type"] != "linear":
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
@classmethod
def from_dict(cls, params):
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def _norm(self, x):
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
def __call__(self, x):
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
return self.weight * output
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.repeats = n_heads // n_kv_heads
head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
rope_scale = (
1 / args.rope_scaling["factor"]
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
else 1
)
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
def repeat(a):
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
return a.reshape([B, self.n_heads, L, -1])
if self.repeats > 1:
keys, values = map(repeat, (keys, values))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores += mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), (keys, values)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out, cache
class LlamaModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
return self.norm(h), cache
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model = LlamaModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out, cache = self.model(inputs, cache)
return self.lm_head(out), cache
def load(path_or_hf_repo: str):
# If the path exists, try to load the model from it;
# otherwise, download it from the Hugging Face Hub and cache it.
model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
)
)
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
quantization = config.get("quantization", None)
model_args = ModelArgs.from_dict(config)
weight_files = glob.glob(str(model_path / "*.safetensors"))
if len(weight_files) == 0:
raise FileNotFoundError("No safetensors found in {}".format(model_path))
weights = {}
for wf in weight_files:
weights.update(mx.load(wf).items())
model = Model(model_args)
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
model.load_weights(list(weights.items()))
mx.eval(model.parameters())
tokenizer = AutoTokenizer.from_pretrained(model_path)
return model, tokenizer
def generate(prompt: mx.array, model: Model, temp: float = 0.0):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
y = prompt
cache = None
while True:
logits, cache = model(y[None], cache=cache)
logits = logits[:, -1, :]
y = sample(logits)
yield y

llms/mlx_lm/README.md Normal file (+7)

@@ -0,0 +1,7 @@
## Generate Text with MLX and :hugs: Hugging Face
This is an example of large language model text generation that can pull models from
the Hugging Face Hub.
For more information on this example, see the
[README](../README.md) in the parent directory.
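A minimal quick start, using the same command documented in the parent README:
```
python -m mlx_lm.generate --model mistralai/Mistral-7B-v0.1 --prompt "hello"
```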

llms/mlx_lm/UPLOAD.md Normal file (+37)

@@ -0,0 +1,37 @@
### Packaging for PyPI
Install `build` and `twine`:
```
pip install --user --upgrade build
pip install --user --upgrade twine
```
Generate the source distribution and wheel:
```
python -m build
```
> [!WARNING]
> Use a test server first
#### Test Upload
Upload to test server:
```
python -m twine upload --repository testpypi dist/*
```
Install from test server and check that it works:
```
python -m pip install --index-url https://test.pypi.org/simple/ --no-deps mlx-lm
```
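A quick smoke test of the test install, importing the two functions the package exports (this assumes the dependencies are already present, since the install above uses `--no-deps`):
```
python -c "from mlx_lm import load, generate; print('ok')"
```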
#### Upload
```
python -m twine upload dist/*
```

llms/mlx_lm/__init__.py Normal file (+2)

@@ -0,0 +1,2 @@
from .convert import convert
from .utils import generate, load

View File

@@ -9,7 +9,8 @@ import mlx.core as mx
import mlx.nn as nn
import transformers
from mlx.utils import tree_flatten
from utils import get_model_path, load
from .utils import get_model_path, load
MAX_FILE_SIZE_GB = 15
@@ -73,26 +74,30 @@ def fetch_from_hub(
return weights, config.to_dict(), tokenizer
def quantize(weights: dict, config: dict, args: argparse.Namespace) -> tuple:
def quantize_model(
weights: dict, config: dict, hf_path: str, q_group_size: int, q_bits: int
) -> tuple:
"""
Applies quantization to the model weights.
Args:
weights (dict): Model weights.
config (dict): Model configuration.
args (argparse.Namespace): Command-line arguments.
hf_path (str): HF model path.
q_group_size (int): Group size for quantization.
q_bits (int): Bits per weight for quantization.
Returns:
tuple: Tuple containing quantized weights and config.
"""
quantized_config = copy.deepcopy(config)
model, _ = load(args.hf_path)
model, _ = load(hf_path)
model.load_weights(list(weights.items()))
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
nn.QuantizedLinear.quantize_module(model, q_group_size, q_bits)
quantized_config["quantization"] = {
"group_size": args.q_group_size,
"bits": args.q_bits,
"group_size": q_group_size,
"bits": q_bits,
}
quantized_weights = dict(tree_flatten(model.parameters()))
@@ -148,7 +153,7 @@ Refer to the [original model card](https://huggingface.co/{hf_path}) for more de
pip install mlx
git clone https://github.com/ml-explore/mlx-examples.git
cd mlx-examples/llms/hf_llm
python generate.py --model {repo_id} --prompt "My name is"
python generate.py --model {upload_repo} --prompt "My name is"
```
"""
card.save(os.path.join(path, "README.md"))
@@ -164,20 +169,24 @@ python generate.py --model {repo_id} --prompt "My name is"
)
if __name__ == "__main__":
parser = configure_parser()
args = parser.parse_args()
def convert(
hf_path: str,
mlx_path: str = "mlx_model",
quantize: bool = False,
q_group_size: int = 64,
q_bits: int = 4,
dtype: str = "float16",
upload_repo: str = None,
):
print("[INFO] Loading")
weights, config, tokenizer = fetch_from_hub(args.hf_path)
dtype = mx.float16 if args.quantize else getattr(mx, args.dtype)
weights, config, tokenizer = fetch_from_hub(hf_path)
dtype = mx.float16 if quantize else getattr(mx, dtype)
weights = {k: v.astype(dtype) for k, v in weights.items()}
if args.quantize:
if quantize:
print("[INFO] Quantizing")
weights, config = quantize(weights, config, args)
weights, config = quantize_model(weights, config, hf_path, q_group_size, q_bits)
mlx_path = Path(args.mlx_path)
mlx_path = Path(mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
shards = make_shards(weights)
for i, shard in enumerate(shards):
@@ -186,5 +195,11 @@ if __name__ == "__main__":
with open(mlx_path / "config.json", "w") as fid:
json.dump(config, fid, indent=4)
if args.upload_repo is not None:
upload_to_hub(mlx_path, args.upload_repo, args.hf_path)
if upload_repo is not None:
upload_to_hub(mlx_path, upload_repo, hf_path)
if __name__ == "__main__":
parser = configure_parser()
args = parser.parse_args()
convert(**vars(args))

View File

@@ -2,7 +2,8 @@ import argparse
import time
import mlx.core as mx
from utils import generate, load
from .utils import generate_step, load
DEFAULT_MODEL_PATH = "mlx_model"
DEFAULT_PROMPT = "hello"
@@ -47,7 +48,9 @@ def main(args):
tic = time.time()
tokens = []
skip = 0
for token, n in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
for token, n in zip(
generate_step(prompt, model, args.temp), range(args.max_tokens)
):
if token == tokenizer.eos_token_id:
break
if n == 0:

View File

@@ -1,4 +1,4 @@
mlx>=0.0.7
mlx
numpy
transformers
protobuf

View File

@@ -6,13 +6,12 @@ from typing import Generator, Tuple
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, PreTrainedTokenizer
# Local imports
import models.llama as llama
import models.phi2 as phi2
from huggingface_hub import snapshot_download
from models.base import BaseModelArgs
from transformers import AutoTokenizer, PreTrainedTokenizer
from .models import llama, phi2
from .models.base import BaseModelArgs
# Constants
MODEL_MAPPING = {
@@ -64,11 +63,11 @@ def get_model_path(path_or_hf_repo: str) -> Path:
return model_path
def generate(
def generate_step(
prompt: mx.array, model: nn.Module, temp: float = 0.0
) -> Generator[mx.array, None, None]:
"""
Generate text based on the given prompt and model.
A generator producing text based on the given prompt from the model.
Args:
prompt (mx.array): The input prompt.
@@ -76,7 +75,7 @@ def generate(
temp (float): The temperature for sampling. If temp is 0, use max sampling.
Yields:
mx.array: The generated text.
Generator[mx.array]: A generator producing one token per call.
"""
def sample(logits: mx.array) -> mx.array:
@@ -95,6 +94,46 @@ def generate(
yield y
def generate(
model: nn.Module,
tokenizer: PreTrainedTokenizer,
prompt: str,
temp: float = 0.0,
max_tokens: int = 100,
verbose: bool = False,
) -> str:
"""
Generate text from the model.
Args:
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (str): The string prompt.
temp (float): The temperature for sampling (default 0).
max_tokens (int): The maximum number of tokens (default 100).
"""
prompt = mx.array(tokenizer.encode(prompt))
tokens = []
skip = 0
for token, _ in zip(generate_step(prompt, model, temp), range(max_tokens)):
if token == tokenizer.eos_token_id:
break
tokens.append(token.item())
if verbose:
s = tokenizer.decode(tokens)
print(s[skip:], end="", flush=True)
skip = len(s)
tokens = tokenizer.decode(tokens)[skip:]
if verbose:
print(tokens, flush=True)
return tokens
def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
"""
Load the model from a given path or a huggingface repository.

llms/setup.py Normal file (+23)

@@ -0,0 +1,23 @@
import sys
from pathlib import Path
import pkg_resources
from setuptools import setup
with open(Path(__file__).parent / "mlx_lm/requirements.txt") as fid:
requirements = [str(r) for r in pkg_resources.parse_requirements(fid)]
setup(
name="mlx-lm",
version="0.0.1",
description="LLMs on Apple silicon with MLX and the Hugging Face Hub",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
readme="README.md",
author_email="mlx@group.apple.com",
author="MLX Contributors",
url="https://github.com/ml-explore/mlx-examples",
license="MIT",
install_requires=requirements,
packages=["mlx_lm", "mlx_lm.models"],
python_requires=">=3.8",
)
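For local development, an editable install from the `llms/` directory should also work with this `setup.py` (standard pip usage, not something introduced by this PR):
```shell
cd llms
pip install -e .
```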