* initial

* file

* remove debug

* Adding README

* typo

* simplify readme

* nits in readmes

---------

Co-authored-by: Marcel Bischoff <marcel.bischoff@awarehq.com>
Co-authored-by: Awni Hannun <awni@apple.com>
Marcel Bischoff 2024-01-13 11:35:03 -05:00 committed by GitHub
parent a39b735c3b
commit cd3cff0858
6 changed files with 393 additions and 5 deletions

@@ -84,7 +84,7 @@ You can upload new models to Hugging Face by specifying `--upload-repo` to
python -m mlx_lm.convert \
--hf-path mistralai/Mistral-7B-v0.1 \
-q \
- --upload-repo mlx-community/my-4bit-mistral \
+ --upload-repo mlx-community/my-4bit-mistral
```
### Supported Models

llms/phixtral/README.md (new file, +28 lines)

@@ -0,0 +1,28 @@
# Phixtral
Phixtral is a Mixture of Experts (MoE) architecture inspired by
[Mixtral](../mixtral/README.md), but made by combining fine-tuned versions of
Phi-2.[^1][^2]
### Setup
Install the dependencies:
```
pip install -r requirements.txt
```
### Run
```
python generate.py \
--model mlabonne/phixtral-4x2_8 \
--prompt "write a quick sort in Python"
```
Run `python generate.py --help` to see all the options.
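For example, to generate a longer completion with sampling enabled (both `--max-tokens` and `--temp` are among the options listed by `--help`; the values below are just illustrative):
```
python generate.py \
    --model mlabonne/phixtral-4x2_8 \
    --max-tokens 256 \
    --temp 0.7 \
    --prompt "write a quick sort in Python"
```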
[^1]: For more details on Phixtral, see the [Hugging Face repo](https://huggingface.co/mlabonne/phixtral-4x2_8).
[^2]: For more details on Phi-2, see Microsoft's [blog post](
https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/)
and the [Hugging Face repo](https://huggingface.co/microsoft/phi-2).

llms/phixtral/generate.py (new file, +91 lines)

@@ -0,0 +1,91 @@
# Copyright © 2023 Apple Inc.
import argparse
import time
import mlx.core as mx
import phixtral
import transformers
def generate(
model: phixtral.Model,
tokenizer: transformers.AutoTokenizer,
prompt: str,
max_tokens: int,
temp: float = 0.0,
):
print("[INFO] Generating with Phixtral...", flush=True)
print(prompt, end="", flush=True)
prompt = tokenizer(
prompt,
return_tensors="np",
return_attention_mask=False,
)[
"input_ids"
][0]
prompt = mx.array(prompt)
tic = time.time()
tokens = []
skip = 0
for token, n in zip(
phixtral.generate(prompt, model, temp),
range(max_tokens),
):
if token == tokenizer.eos_token_id:
break
if n == 0:
prompt_time = time.time() - tic
tic = time.time()
tokens.append(token.item())
# if (n + 1) % 10 == 0:
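        # Decode everything generated so far and print only the new suffix;
        # `skip` tracks how many characters have already been printed.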
s = tokenizer.decode(tokens)
print(s[skip:], end="", flush=True)
skip = len(s)
print(tokenizer.decode(tokens)[skip:], flush=True)
gen_time = time.time() - tic
print("=" * 10)
if len(tokens) == 0:
print("No tokens generated for this prompt")
return
prompt_tps = prompt.size / prompt_time
gen_tps = (len(tokens) - 1) / gen_time
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="inference script")
parser.add_argument(
"--model",
type=str,
default="mlx_model",
help="The path to the local model directory or Hugging Face repo.",
)
parser.add_argument(
"--prompt",
help="The message to be processed by the model",
default="Write a detailed analogy between mathematics and a lighthouse.",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=100,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--temp",
help="The sampling temperature.",
type=float,
default=0.0,
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
args = parser.parse_args()
mx.random.seed(args.seed)
model, tokenizer = phixtral.load(args.model)
generate(model, tokenizer, args.prompt, args.max_tokens, args.temp)

llms/phixtral/phixtral.py (new file, +262 lines)

@@ -0,0 +1,262 @@
import glob
import inspect
import json
import math
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten
from transformers import AutoTokenizer
@dataclass
class ModelArgs:
max_sequence_length: int = 2048
num_vocab: int = 51200
model_dim: int = 2560
num_heads: int = 32
num_layers: int = 32
rotary_dim: int = 32
num_experts_per_tok: int = 2
num_local_experts: int = 4
@classmethod
def from_dict(cls, params):
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
class LayerNorm(nn.LayerNorm):
def __call__(self, x: mx.array) -> mx.array:
return super().__call__(x.astype(mx.float32)).astype(x.dtype)
class RoPEAttention(nn.Module):
def __init__(self, dims: int, num_heads: int, rotary_dim: int):
super().__init__()
self.num_heads = num_heads
self.rope = nn.RoPE(rotary_dim, traditional=False)
self.Wqkv = nn.Linear(dims, 3 * dims)
self.out_proj = nn.Linear(dims, dims)
def __call__(self, x, mask=None, cache=None):
qkv = self.Wqkv(x)
queries, keys, values = mx.split(qkv, 3, axis=-1)
# Extract some shapes
num_heads = self.num_heads
B, L, D = queries.shape
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
queries = queries.astype(mx.float32)
keys = keys.astype(mx.float32)
# Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1).astype(values.dtype)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat), (keys, values)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.fc1 = nn.Linear(dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, dim)
self.act = nn.GELU(approx="precise")
def __call__(self, x) -> mx.array:
return self.fc2(self.act(self.fc1(x)))
class MOE(nn.Module):
def __init__(self, args: ModelArgs, dim: int, hidden_dim: int):
super().__init__()
self.dim = dim
self.hidden_dim = hidden_dim
self.num_experts = args.num_local_experts
self.num_experts_per_tok = args.num_experts_per_tok
self.mlp = [MLP(self.dim, self.hidden_dim) for _ in range(self.num_experts)]
self.gate = nn.Linear(args.model_dim, self.num_experts, bias=False)
def __call__(self, x) -> mx.array:
ne = self.num_experts_per_tok
orig_shape = x.shape
x = x.reshape(-1, x.shape[-1])
gates = self.gate(x)
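        # Route each token to its top `ne` experts: argpartition on the negated
        # logits puts the indices of the `ne` largest gate values first.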
if ne < self.num_experts:
inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
else:
inds = mx.broadcast_to(mx.arange(ne), gates.shape)
scores = mx.softmax(
mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
axis=-1,
).astype(gates.dtype)
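        # Run each token through its selected experts and mix their outputs,
        # weighted by the renormalized (softmax) gate scores.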
y = []
for xt, st, it in zip(x, scores, inds.tolist()):
yt = mx.concatenate([self.mlp[e](xt)[:, None] for e in it], axis=-1)
yt = (yt * st).sum(axis=-1)
y.append(yt[None, :])
yc = mx.concatenate(y)
return yc.reshape(orig_shape)
class ParallelBlock(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
dims = config.model_dim
mlp_dims = dims * 4
self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
self.ln = LayerNorm(dims)
self.moe = MOE(config, dims, mlp_dims)
def __call__(self, x, mask, cache):
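        # Phi-style parallel block: attention and the MoE feed-forward both read
        # the same normalized input, and their outputs are added to the residual.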
h = self.ln(x)
attn_h, cache = self.mixer(h, mask, cache)
ff_h = self.moe(h)
return attn_h + ff_h + x, cache
class TransformerDecoder(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.embd = Embd(config)
self.h = [ParallelBlock(config) for i in range(config.num_layers)]
def __call__(self, x, mask, cache):
x = self.embd(x)
if cache is None:
cache = [None] * len(self.h)
for e, layer in enumerate(self.h):
x, cache[e] = layer(x, mask, cache[e])
return x, cache
class Embd(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.wte = nn.Embedding(config.num_vocab, config.model_dim)
def __call__(self, x):
return self.wte(x)
class OutputHead(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.ln = LayerNorm(config.model_dim)
self.linear = nn.Linear(config.model_dim, config.num_vocab)
def __call__(self, inputs):
return self.linear(self.ln(inputs))
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.transformer = TransformerDecoder(config)
self.lm_head = OutputHead(config)
def __call__(
self,
x: mx.array,
mask: mx.array = None,
cache: mx.array = None,
) -> tuple[mx.array, mx.array]:
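        # A causal mask is only needed when more than one token is processed at
        # once (the prompt); single-token decode steps attend to the whole cache.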
mask = None
if x.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(x.dtype)
y, cache = self.transformer(x, mask, cache)
return self.lm_head(y), cache
def generate(prompt: mx.array, model: Model, temp: float = 0.0):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
y = prompt
cache = None
while True:
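        # The first pass runs the full prompt through the model; afterwards `y`
        # is a single token and the KV cache avoids recomputing past positions.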
logits, cache = model(y[None], cache=cache)
logits = logits[:, -1, :]
y = sample(logits)
yield y
def load(path_or_hf_repo: str):
    # If the path exists, try to load the model from it;
    # otherwise download it from the Hugging Face Hub and cache it.
model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
)
)
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
quantization = config.get("quantization", None)
model_args = ModelArgs.from_dict(config)
weight_files = glob.glob(str(model_path / "*.safetensors"))
if len(weight_files) == 0:
raise FileNotFoundError("No safetensors found in {}".format(model_path))
weights = {}
for wf in weight_files:
weights.update(mx.load(wf).items())
model = Model(model_args)
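    # For a quantized checkpoint, swap in QuantizedLinear layers so the parameter
    # shapes match the stored weights before loading them.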
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
model.load_weights(list(weights.items()))
mx.eval(model.parameters())
tokenizer = AutoTokenizer.from_pretrained(
model_path,
)
return model, tokenizer

@@ -0,0 +1,7 @@
einops
hf_transfer
huggingface_hub
mlx
numpy
torch
transformers>=4.35

@@ -81,7 +81,7 @@ To fine-tune a model use:
```
python lora.py --model <path_to_model> \
--train \
- --iters 600 \
+ --iters 600
```
If `--model` points to a quantized model, then the training will use QLoRA,
@@ -100,7 +100,7 @@ To compute test set perplexity use:
```
python lora.py --model <path_to_model> \
--adapter-file <path_to_adapters.npz> \
- --test \
+ --test
```
### Generate
@@ -114,7 +114,7 @@ python lora.py --model <path_to_model> \
--prompt "table: 1-10015132-16
columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
Q: What is terrence ross' nationality
- A: " \
+ A: "
```
## Results
@@ -211,7 +211,7 @@ python lora.py \
--model mistralai/Mistral-7B-v0.1 \
--train \
--batch-size 1 \
- --lora-layers 4 \
+ --lora-layers 4
```
The above command on an M1 Max with 32 GB runs at about 250 tokens-per-second.