diff --git a/llms/README.md b/llms/README.md
index 995ce1d3..fd33e98a 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -84,7 +84,7 @@ You can upload new models to Hugging Face by specifying `--upload-repo` to
 python -m mlx_lm.convert \
     --hf-path mistralai/Mistral-7B-v0.1 \
     -q \
-    --upload-repo mlx-community/my-4bit-mistral \
+    --upload-repo mlx-community/my-4bit-mistral
 ```
 
 ### Supported Models
diff --git a/llms/phixtral/README.md b/llms/phixtral/README.md
new file mode 100644
index 00000000..93524c0a
--- /dev/null
+++ b/llms/phixtral/README.md
@@ -0,0 +1,28 @@
+# Phixtral
+
+Phixtral is a Mixture of Experts (MoE) architecture inspired by
+[Mixtral](../mixtral/README.md) but made by combining fine-tuned versions of
+Phi-2.[^1][^2]
+
+### Setup
+
+Install the dependencies:
+
+```
+pip install -r requirements.txt
+```
+
+### Run
+
+```
+python generate.py \
+  --model mlabonne/phixtral-4x2_8 \
+  --prompt "write a quick sort in Python"
+```
+
+Run `python generate.py --help` to see all the options.
+
+[^1]: For more details on Phixtral, see the [Hugging Face repo](https://huggingface.co/mlabonne/phixtral-4x2_8).
+[^2]: For more details on Phi-2, see Microsoft's [blog post](
+https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/)
+and the [Hugging Face repo](https://huggingface.co/microsoft/phi-2).
diff --git a/llms/phixtral/generate.py b/llms/phixtral/generate.py
new file mode 100644
index 00000000..e1767e34
--- /dev/null
+++ b/llms/phixtral/generate.py
@@ -0,0 +1,91 @@
+# Copyright © 2023 Apple Inc.
+
+import argparse
+import time
+
+import mlx.core as mx
+import phixtral
+import transformers
+
+
+def generate(
+    model: phixtral.Model,
+    tokenizer: transformers.AutoTokenizer,
+    prompt: str,
+    max_tokens: int,
+    temp: float = 0.0,
+):
+    print("[INFO] Generating with Phixtral...", flush=True)
+    print(prompt, end="", flush=True)
+    prompt = tokenizer(
+        prompt,
+        return_tensors="np",
+        return_attention_mask=False,
+    )[
+        "input_ids"
+    ][0]
+    prompt = mx.array(prompt)
+
+    tic = time.time()
+    tokens = []
+    skip = 0
+    for token, n in zip(
+        phixtral.generate(prompt, model, temp),
+        range(max_tokens),
+    ):
+        if token == tokenizer.eos_token_id:
+            break
+
+        if n == 0:
+            prompt_time = time.time() - tic
+            tic = time.time()
+
+        tokens.append(token.item())
+        # if (n + 1) % 10 == 0:
+        s = tokenizer.decode(tokens)
+        print(s[skip:], end="", flush=True)
+        skip = len(s)
+    print(tokenizer.decode(tokens)[skip:], flush=True)
+    gen_time = time.time() - tic
+    print("=" * 10)
+    if len(tokens) == 0:
+        print("No tokens generated for this prompt")
+        return
+    prompt_tps = prompt.size / prompt_time
+    gen_tps = (len(tokens) - 1) / gen_time
+    print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
+    print(f"Generation: {gen_tps:.3f} tokens-per-sec")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="inference script")
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="mlx_model",
+        help="The path to the local model directory or Hugging Face repo.",
+    )
+    parser.add_argument(
+        "--prompt",
+        help="The message to be processed by the model",
+        default="Write a detailed analogy between mathematics and a lighthouse.",
+    )
+    parser.add_argument(
+        "--max-tokens",
+        "-m",
+        type=int,
+        default=100,
+        help="Maximum number of tokens to generate",
+    )
+    parser.add_argument(
+        "--temp",
+        help="The sampling temperature.",
+        type=float,
+        default=0.0,
+    )
+    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
+
+    args = parser.parse_args()
+    mx.random.seed(args.seed)
+    model, tokenizer = phixtral.load(args.model)
+    generate(model, tokenizer, args.prompt, args.max_tokens, args.temp)
diff --git a/llms/phixtral/phixtral.py b/llms/phixtral/phixtral.py
new file mode 100644
index 00000000..6774bc1f
--- /dev/null
+++ b/llms/phixtral/phixtral.py
@@ -0,0 +1,262 @@
+import glob
+import inspect
+import json
+import math
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+from huggingface_hub import snapshot_download
+from mlx.utils import tree_unflatten
+from transformers import AutoTokenizer
+
+
+@dataclass
+class ModelArgs:
+    max_sequence_length: int = 2048
+    num_vocab: int = 51200
+    model_dim: int = 2560
+    num_heads: int = 32
+    num_layers: int = 32
+    rotary_dim: int = 32
+    num_experts_per_tok: int = 2
+    num_local_experts: int = 4
+
+    @classmethod
+    def from_dict(cls, params):
+        return cls(
+            **{
+                k: v
+                for k, v in params.items()
+                if k in inspect.signature(cls).parameters
+            }
+        )
+
+
+class LayerNorm(nn.LayerNorm):
+    def __call__(self, x: mx.array) -> mx.array:
+        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
+
+
+class RoPEAttention(nn.Module):
+    def __init__(self, dims: int, num_heads: int, rotary_dim: int):
+        super().__init__()
+
+        self.num_heads = num_heads
+
+        self.rope = nn.RoPE(rotary_dim, traditional=False)
+        self.Wqkv = nn.Linear(dims, 3 * dims)
+        self.out_proj = nn.Linear(dims, dims)
+
+    def __call__(self, x, mask=None, cache=None):
+        qkv = self.Wqkv(x)
+        queries, keys, values = mx.split(qkv, 3, axis=-1)
+
+        # Extract some shapes
+        num_heads = self.num_heads
+        B, L, D = queries.shape
+
+        # Prepare the queries, keys and values for the attention computation
+        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+
+        # Add RoPE to the queries and keys and combine them with the cache
+        if cache is not None:
+            key_cache, value_cache = cache
+            queries = self.rope(queries, offset=key_cache.shape[2])
+            keys = self.rope(keys, offset=key_cache.shape[2])
+            keys = mx.concatenate([key_cache, keys], axis=2)
+            values = mx.concatenate([value_cache, values], axis=2)
+        else:
+            queries = self.rope(queries)
+            keys = self.rope(keys)
+
+        queries = queries.astype(mx.float32)
+        keys = keys.astype(mx.float32)
+
+        # Finally perform the attention computation
+        scale = math.sqrt(1 / queries.shape[-1])
+        scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
+        if mask is not None:
+            scores = scores + mask
+
+        scores = mx.softmax(scores, axis=-1).astype(values.dtype)
+        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+        return self.out_proj(values_hat), (keys, values)
+
+
+class MLP(nn.Module):
+    def __init__(self, dim, hidden_dim):
+        super().__init__()
+        self.fc1 = nn.Linear(dim, hidden_dim)
+        self.fc2 = nn.Linear(hidden_dim, dim)
+        self.act = nn.GELU(approx="precise")
+
+    def __call__(self, x) -> mx.array:
+        return self.fc2(self.act(self.fc1(x)))
+
+
+class MOE(nn.Module):
+    def __init__(self, args: ModelArgs, dim: int, hidden_dim: int):
+        super().__init__()
+        self.dim = dim
+        self.hidden_dim = hidden_dim
+        self.num_experts = args.num_local_experts
+        self.num_experts_per_tok = args.num_experts_per_tok
+        self.mlp = [MLP(self.dim, self.hidden_dim) for _ in range(self.num_experts)]
+        self.gate = nn.Linear(args.model_dim, self.num_experts, bias=False)
+
+    def __call__(self, x) -> mx.array:
+        ne = self.num_experts_per_tok
+        orig_shape = x.shape
+        x = x.reshape(-1, x.shape[-1])
+
+        gates = self.gate(x)
+        if ne < self.num_experts:
+            inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
+        else:
+            inds = mx.broadcast_to(mx.arange(ne), gates.shape)
+
+        scores = mx.softmax(
+            mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
+            axis=-1,
+        ).astype(gates.dtype)
+
+        y = []
+        for xt, st, it in zip(x, scores, inds.tolist()):
+            yt = mx.concatenate([self.mlp[e](xt)[:, None] for e in it], axis=-1)
+            yt = (yt * st).sum(axis=-1)
+            y.append(yt[None, :])
+        yc = mx.concatenate(y)
+
+        return yc.reshape(orig_shape)
+
+
+class ParallelBlock(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        dims = config.model_dim
+        mlp_dims = dims * 4
+        self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
+        self.ln = LayerNorm(dims)
+        self.moe = MOE(config, dims, mlp_dims)
+
+    def __call__(self, x, mask, cache):
+        h = self.ln(x)
+        attn_h, cache = self.mixer(h, mask, cache)
+        ff_h = self.moe(h)
+        return attn_h + ff_h + x, cache
+
+
+class TransformerDecoder(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        self.embd = Embd(config)
+        self.h = [ParallelBlock(config) for i in range(config.num_layers)]
+
+    def __call__(self, x, mask, cache):
+        x = self.embd(x)
+        if cache is None:
+            cache = [None] * len(self.h)
+
+        for e, layer in enumerate(self.h):
+            x, cache[e] = layer(x, mask, cache[e])
+        return x, cache
+
+
+class Embd(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        self.wte = nn.Embedding(config.num_vocab, config.model_dim)
+
+    def __call__(self, x):
+        return self.wte(x)
+
+
+class OutputHead(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.ln = LayerNorm(config.model_dim)
+        self.linear = nn.Linear(config.model_dim, config.num_vocab)
+
+    def __call__(self, inputs):
+        return self.linear(self.ln(inputs))
+
+
+class Model(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        self.transformer = TransformerDecoder(config)
+        self.lm_head = OutputHead(config)
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: mx.array = None,
+        cache: mx.array = None,
+    ) -> tuple[mx.array, mx.array]:
+        mask = None
+        if x.shape[1] > 1:
+            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
+            mask = mask.astype(x.dtype)
+
+        y, cache = self.transformer(x, mask, cache)
+        return self.lm_head(y), cache
+
+
+def generate(prompt: mx.array, model: Model, temp: float = 0.0):
+    def sample(logits):
+        if temp == 0:
+            return mx.argmax(logits, axis=-1)
+        else:
+            return mx.random.categorical(logits * (1 / temp))
+
+    y = prompt
+    cache = None
+    while True:
+        logits, cache = model(y[None], cache=cache)
+        logits = logits[:, -1, :]
+        y = sample(logits)
+        yield y
+
+
+def load(path_or_hf_repo: str):
+    # If the path exists, try to load the model from it,
+    # otherwise download it from the Hugging Face repo and cache it locally.
+    model_path = Path(path_or_hf_repo)
+    if not model_path.exists():
+        model_path = Path(
+            snapshot_download(
+                repo_id=path_or_hf_repo,
+                allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
+            )
+        )
+
+    with open(model_path / "config.json", "r") as f:
+        config = json.loads(f.read())
+        quantization = config.get("quantization", None)
+        model_args = ModelArgs.from_dict(config)
+
+    weight_files = glob.glob(str(model_path / "*.safetensors"))
+    if len(weight_files) == 0:
+        raise FileNotFoundError("No safetensors found in {}".format(model_path))
{}".format(model_path)) + + weights = {} + for wf in weight_files: + weights.update(mx.load(wf).items()) + + model = Model(model_args) + if quantization is not None: + nn.QuantizedLinear.quantize_module(model, **quantization) + + model.load_weights(list(weights.items())) + + mx.eval(model.parameters()) + tokenizer = AutoTokenizer.from_pretrained( + model_path, + ) + return model, tokenizer diff --git a/llms/phixtral/requirements.txt b/llms/phixtral/requirements.txt new file mode 100644 index 00000000..016af3ae --- /dev/null +++ b/llms/phixtral/requirements.txt @@ -0,0 +1,7 @@ +einops +hf_transfer +huggingface_hub +mlx +numpy +torch +transformers>=4.35 diff --git a/lora/README.md b/lora/README.md index d9de523c..9c379679 100644 --- a/lora/README.md +++ b/lora/README.md @@ -81,7 +81,7 @@ To fine-tune a model use: ``` python lora.py --model \ --train \ - --iters 600 \ + --iters 600 ``` If `--model` points to a quantized model, then the training will use QLoRA, @@ -100,7 +100,7 @@ To compute test set perplexity use: ``` python lora.py --model \ --adapter-file \ - --test \ + --test ``` ### Generate @@ -114,7 +114,7 @@ python lora.py --model \ --prompt "table: 1-10015132-16 columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team Q: What is terrence ross' nationality -A: " \ +A: " ``` ## Results @@ -211,7 +211,7 @@ python lora.py \ --model mistralai/Mistral-7B-v0.1 \ --train \ --batch-size 1 \ - --lora-layers 4 \ + --lora-layers 4 ``` The above command on an M1 Max with 32 GB runs at about 250 tokens-per-second.