Merge pull request #82 from ml-explore/llamav2

llama v2 with sharded weights
Awni Hannun 2023-12-12 17:08:24 -08:00 committed by GitHub
commit a614e951c4
9 changed files with 201 additions and 133 deletions

View File

@@ -25,7 +25,7 @@ def run(bert_model: str):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="Run the BERT model using HuggingFace Transformers."
+        description="Run the BERT model using Hugging Face Transformers."
     )
     parser.add_argument(
         "--bert-model",

View File

@@ -1,8 +1,9 @@
-# LLaMA
+# Llama

-An example of generating text with LLaMA using MLX.
+An example of generating text with Llama (1 or 2) using MLX.

-LLaMA is a set of open source language models from Meta AI Research[^1] ranging from 7B to 65B parameters.
+Llama is a set of open source language models from Meta AI Research[^1][^2]
+ranging from 7B to 70B parameters.

 ### Setup
@@ -13,28 +14,34 @@ pip install -r requirements.txt
 ```

 Next, download and convert the model. If you do not have access to the model
-weights you will need to [request
-access](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
-from Meta.
+weights you will need to request access from Meta:
+
+- [Request Llama v1](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
+- [Request Llama v2](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)

-Alternatively, you can also download a select converted checkpoints from the [mlx-llama](https://huggingface.co/mlx-llama) community organisation on Hugging Face and skip the conversion step.
+Alternatively, you can also download select converted checkpoints from the
+[mlx-llama](https://huggingface.co/mlx-llama) community organisation on Hugging
+Face and skip the conversion step.

 Convert the weights with:

 ```
-python convert.py <path_to_torch_weights> <path_to_mlx_llama_weights.npz>
+python convert.py --model_path <path_to_torch_model>
 ```

+The conversion script will save the converted weights in the same location.
+
 ### Run

 Once you've converted the weights to MLX format, you can interact with the
-LLaMA model:
+Llama model:

 ```
-python llama.py <path_to_mlx_llama_weights.npz> <path_to_tokenizer.model> "hello"
+python llama.py <path_to_model> <path_to_tokenizer.model> "hello"
 ```

 Run `python llama.py --help` for more details.

-[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
+[^1]: For Llama v1 refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
+[^2]: For Llama v2 refer to the [blog post](https://ai.meta.com/llama/) for more details.
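A quick way to check that a conversion worked is to open the resulting `weights.npz` (the file name `convert.py` writes, per the diff below) with NumPy and inspect a few shapes. This is an illustrative sketch, not part of the example; the path is a placeholder.

```python
# Hedged sketch: inspect a converted checkpoint. "path/to/model" is a placeholder.
import numpy as np

weights = np.load("path/to/model/weights.npz")
print(len(weights.files), "arrays")
print("tok_embeddings.weight:", weights["tok_embeddings.weight"].shape)  # (vocab_size, dim)
```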

View File

@@ -1,53 +1,59 @@
 # Copyright © 2023 Apple Inc.

 import argparse
-from itertools import starmap
+import collections
+import glob
+from pathlib import Path

 import numpy as np
 import torch

+SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
+SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
+SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)

-def map_torch_to_mlx(key, value):
-    if "tok_embedding" in key:
-        key = "embedding.weight"
-
-    elif "norm" in key:
-        key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
-
-    elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
-        key = key.replace("wq", "query_proj")
-        key = key.replace("wk", "key_proj")
-        key = key.replace("wv", "value_proj")
-        key = key.replace("wo", "out_proj")
-
-    elif "w1" in key or "w2" in key or "w3" in key:
-        # The FFN is a separate submodule in PyTorch
-        key = key.replace("feed_forward.w1", "linear1")
-        key = key.replace("feed_forward.w3", "linear2")
-        key = key.replace("feed_forward.w2", "linear3")
-
-    elif "output" in key:
-        key = key.replace("output", "out_proj")
-
-    elif "rope" in key:
-        return None, None
-
-    return (
-        key,
-        value.numpy()
-        if value.dtype != torch.bfloat16
-        else value.to(torch.float32).numpy(),
-    )
+
+def shard_key(k):
+    keys = k.split(".")
+    if len(keys) < 2:
+        return None
+    return keys[-2]
+
+
+def unshard(k, v):
+    wn = shard_key(k)
+    if wn not in SHARD_WEIGHTS:
+        return v
+    elif wn in SHARD_FIRST:
+        axis = 0
+    elif wn in SHARD_SECOND:
+        axis = 1
+    else:
+        raise ValueError("Invalid weight name")
+    return np.concatenate(v, axis=axis)


 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
-    parser.add_argument("torch_weights")
-    parser.add_argument("output_file")
+    parser.add_argument(
+        "--model_path",
+        help="Path to the Torch model. The MLX weights will also be saved there.",
+    )
     args = parser.parse_args()

-    state = torch.load(args.torch_weights, map_location=torch.device('cpu'))
-    np.savez(
-        args.output_file,
-        **{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
-    )
+    model_path = Path(args.model_path)
+    torch_files = glob.glob(str(model_path / "consolidated.*.pth"))
+    weights = collections.defaultdict(list)
+    for wf in torch_files:
+        state = torch.load(wf, map_location=torch.device("cpu"))
+        for k, v in state.items():
+            v = v.to(torch.float16).numpy()
+            if shard_key(k) in SHARD_WEIGHTS:
+                weights[k].append(v)
+            else:
+                weights[k] = v
+
+    out_file = str(model_path / "weights.npz")
+    for k, v in weights.items():
+        weights[k] = unshard(k, v)
+    np.savez(out_file, **weights)
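To see why the two shard lists use different concatenation axes, here is a small standalone sketch of the `unshard` idea with toy arrays (shapes are illustrative, not taken from a real checkpoint): column-parallel weights such as `wq` are split across their output rows, while row-parallel weights such as `wo` are split across their input columns.

```python
# Hedged sketch: why wq shards concatenate on axis 0 but wo shards on axis 1.
import numpy as np

dim, n_shards = 8, 2
# Each shard of a column-parallel weight (wq/wk/wv/w1/w3/output) holds a slice
# of the output rows, so the full matrix is rebuilt along axis 0.
wq_shards = [np.zeros((dim // n_shards, dim)) for _ in range(n_shards)]
print(np.concatenate(wq_shards, axis=0).shape)  # (8, 8)

# Each shard of a row-parallel weight (wo/w2/tok_embeddings) holds a slice of
# the input columns, so it is rebuilt along axis 1.
wo_shards = [np.zeros((dim, dim // n_shards)) for _ in range(n_shards)]
print(np.concatenate(wo_shards, axis=1).shape)  # (8, 8)
```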

View File

@@ -1,8 +1,10 @@
 # Copyright © 2023 Apple Inc.

 import argparse
-import math
-import numpy as np
+from dataclasses import dataclass
+import json
+from pathlib import Path
+from typing import Optional, Tuple, List
 from sentencepiece import SentencePieceProcessor
 import time
@@ -11,33 +13,71 @@ import mlx.nn as nn
 from mlx.utils import tree_unflatten


-class LlamaAttention(nn.Module):
-    def __init__(self, dims: int, num_heads: int):
+@dataclass
+class ModelArgs:
+    dim: int
+    n_layers: int
+    head_dim: int
+    hidden_dim: int
+    n_heads: int
+    n_kv_heads: int
+    norm_eps: float
+    vocab_size: int
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dims: int, eps: float = 1e-5):
         super().__init__()
+        self.weight = mx.ones((dims,))
+        self.eps = eps

-        self.num_heads = num_heads
+    def _norm(self, x):
+        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)

-        self.rope = nn.RoPE(dims // num_heads, traditional=True)
-        self.query_proj = nn.Linear(dims, dims, bias=False)
-        self.key_proj = nn.Linear(dims, dims, bias=False)
-        self.value_proj = nn.Linear(dims, dims, bias=False)
-        self.out_proj = nn.Linear(dims, dims, bias=False)
+    def __call__(self, x):
+        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
+        return self.weight * output

-    def __call__(self, queries, keys, values, mask=None, cache=None):
-        queries = self.query_proj(queries)
-        keys = self.key_proj(keys)
-        values = self.value_proj(values)

-        # Extract some shapes
-        num_heads = self.num_heads
-        B, L, D = queries.shape
+class Attention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+
+        self.n_heads: int = args.n_heads
+        self.n_kv_heads: int = args.n_kv_heads
+        self.repeats = self.n_heads // self.n_kv_heads
+        self.scale = self.args.head_dim**-0.5
+
+        self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
+        self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
+        self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
+        self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
+        self.rope = nn.RoPE(args.head_dim, traditional=True)
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)

         # Prepare the queries, keys and values for the attention computation
-        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
-        values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+        def repeat(a):
+            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
+            return a.reshape([B, self.n_heads, L, -1])
+
+        keys, values = map(repeat, (keys, values))

         # Add RoPE to the queries and keys and combine them with the cache
         if cache is not None:
             key_cache, value_cache = cache
             queries = self.rope(queries, offset=key_cache.shape[2])
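The `repeat` helper is what makes grouped-query attention (`n_kv_heads < n_heads`, as in Llama 2 70B) line up with the query heads: each key/value head is tiled `repeats` times before the attention product. When `n_kv_heads == n_heads` the tiling is a no-op. A hedged NumPy sketch of the same reshaping, with illustrative sizes:

```python
# Hedged NumPy sketch mirroring the repeat() helper above (illustrative sizes).
import numpy as np

B, L, head_dim = 1, 5, 4
n_heads, n_kv_heads = 8, 2
repeats = n_heads // n_kv_heads  # 4

keys = np.random.randn(B, n_kv_heads, L, head_dim)
# Tile each kv head `repeats` times on a new axis, then fold into the head axis.
tiled = np.concatenate([np.expand_dims(keys, 2)] * repeats, axis=2)
keys_full = tiled.reshape(B, n_heads, L, head_dim)
print(keys_full.shape)  # (1, 8, 5, 4)
```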
@@ -48,86 +88,87 @@ class LlamaAttention(nn.Module):
             queries = self.rope(queries)
             keys = self.rope(keys)

-        # Finally perform the attention computation
-        scale = math.sqrt(1 / queries.shape[-1])
-        scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
+        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
         if mask is not None:
-            scores = scores + mask
-        scores = mx.softmax(scores, axis=-1)
-        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-
-        # Note that we return the keys and values to possibly be used as a cache
-        return self.out_proj(values_hat), (keys, values)
+            scores += mask
+        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
+        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.wo(output), (keys, values)


-class LlamaEncoderLayer(nn.Module):
-    def __init__(self, dims: int, mlp_dims: int, num_heads: int):
+class FeedForward(nn.Module):
+    def __init__(self, args: ModelArgs):
         super().__init__()

-        self.attention = LlamaAttention(dims, num_heads)
+        self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
+        self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
+        self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)

-        self.norm1 = nn.RMSNorm(dims)
-        self.norm2 = nn.RMSNorm(dims)
+    def __call__(self, x) -> mx.array:
+        return self.w2(nn.silu(self.w1(x)) * self.w3(x))

-        self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
-        self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
-        self.linear3 = nn.Linear(mlp_dims, dims, bias=False)

-    def __call__(self, x, mask=None, cache=None):
-        y = self.norm1(x)
-        y, cache = self.attention(y, y, y, mask, cache)
-        x = x + y
+class TransformerBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.n_heads = args.n_heads
+        self.dim = args.dim
+        self.attention = Attention(args)
+        self.feed_forward = FeedForward(args=args)
+        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.args = args

-        y = self.norm2(x)
-        a = self.linear1(y)
-        b = self.linear2(y)
-        y = a * mx.sigmoid(a) * b
-        y = self.linear3(y)
-        x = x + y
-
-        return x, cache
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> mx.array:
+        r, cache = self.attention(self.attention_norm(x), mask, cache)
+        h = x + r
+        r = self.feed_forward(self.ffn_norm(h))
+        out = h + r
+        return out, cache


 class Llama(nn.Module):
-    def __init__(
-        self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int
-    ):
+    def __init__(self, args: ModelArgs):
         super().__init__()
-
-        self.embedding = nn.Embedding(vocab_size, dims)
-        self.layers = [
-            LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)
-        ]
-        self.norm = nn.RMSNorm(dims)
-        self.out_proj = nn.Linear(dims, vocab_size, bias=False)
+        self.args = args
+        self.vocab_size = args.vocab_size
+        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
+        self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
+        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

     def __call__(self, x):
         mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
-        mask = mask.astype(self.embedding.weight.dtype)
+        mask = mask.astype(self.tok_embeddings.weight.dtype)

-        x = self.embedding(x)
+        x = self.tok_embeddings(x)
         for l in self.layers:
             x, _ = l(x, mask)
         x = self.norm(x)
-        return self.out_proj(x)
+        return self.output(x)

     def generate(self, x, temp=1.0):
         cache = []

         # Make an additive causal mask. We will need that to process the prompt.
         mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
-        mask = mask.astype(self.embedding.weight.dtype)
+        mask = mask.astype(self.tok_embeddings.weight.dtype)

         # First we process the prompt x the same way as in __call__ but
         # save the caches in cache
-        x = self.embedding(x)
+        x = self.tok_embeddings(x)
         for l in self.layers:
             x, c = l(x, mask=mask)
             # We store the per layer cache in a simple python list
             cache.append(c)
         x = self.norm(x)
         # We only care about the last logits that generate the next token
-        y = self.out_proj(x[:, -1])
+        y = self.output(x[:, -1])
         y = mx.random.categorical(y * (1 / temp))

         # y now has size [1]
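The new `FeedForward` above is the same SwiGLU computation the old encoder layer spelled out by hand: the old code's `a * mx.sigmoid(a) * b` is exactly `silu(a) * b`, with `linear1/linear2/linear3` corresponding to `w1/w3/w2` (the renaming the old `convert.py` applied). A hedged NumPy check of that equivalence, with illustrative shapes:

```python
# Hedged NumPy check: old "a * sigmoid(a) * b" equals the new silu-based FFN.
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def silu(z):
    return z * sigmoid(z)

rng = np.random.default_rng(0)
dim, hidden_dim = 4, 6
x = rng.standard_normal(dim)
w1 = rng.standard_normal((hidden_dim, dim))  # old linear1
w3 = rng.standard_normal((hidden_dim, dim))  # old linear2
w2 = rng.standard_normal((dim, hidden_dim))  # old linear3

a, b = w1 @ x, w3 @ x
old = w2 @ (a * sigmoid(a) * b)  # old LlamaEncoderLayer formulation
new = w2 @ (silu(a) * b)         # new FeedForward formulation
print(np.allclose(old, new))     # True
```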
@@ -145,14 +186,14 @@ class Llama(nn.Module):
             # dimension of 1
             x = y[:, None]

-            x = self.embedding(x)
+            x = self.tok_embeddings(x)
             for i in range(len(cache)):
                 # We are overwriting the arrays in the cache list. When
                 # the computation will happen, MLX will be discarding the
                 # old cache the moment it is not needed anymore.
                 x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
             x = self.norm(x)
-            y = self.out_proj(x[:, -1])
+            y = self.output(x[:, -1])
             y = mx.random.categorical(y * (1 / temp))
             yield y
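For context, `generate` yields one sampled token per iteration, reusing the per-layer key/value cache built from the prompt. A hedged sketch of a driver loop; the helper below is illustrative and not part of the diff:

```python
# Hedged sketch, not from the diff: driving Llama.generate step by step.
import mlx.core as mx
from sentencepiece import SentencePieceProcessor

def sample_text(model, tokenizer_path, prompt, max_tokens=64, temp=0.8):
    tokenizer = SentencePieceProcessor(model_file=tokenizer_path)
    # Prepend BOS and add a batch dimension of 1.
    x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(prompt)])
    tokens = []
    for token in model.generate(x, temp):
        tokens.append(token.item())  # forces evaluation of the lazy graph
        if len(tokens) >= max_tokens:
            break
    return tokenizer.decode(tokens)
```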
@@ -261,20 +302,33 @@ def few_shot_generate(args):

 def load_model(model_path):
-    weights = mx.load(model_path)
-    mlp_dims, dims = weights["layers.0.linear1.weight"].shape
-    num_heads = dims // 128
-    num_layers = max(int(l.split(".")[1]) for l in weights.keys() if "layers" in l) + 1
-    vocab_size = weights["out_proj.weight"].shape[-1]
-    model = Llama(num_layers, vocab_size, dims, mlp_dims, num_heads)
+    model_path = Path(model_path)
+    weights = mx.load(str(model_path / "weights.npz"))
+    with open(model_path / "params.json", "r") as f:
+        config = json.loads(f.read())
+    n_heads = config["n_heads"]
+    if "n_kv_heads" not in config:
+        config["n_kv_heads"] = n_heads
+    if "head_dim" not in config:
+        config["head_dim"] = config["dim"] // n_heads
+    if "hidden_dim" not in config:
+        config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
+    if config.get("vocab_size", -1) < 0:
+        config["vocab_size"] = weights["output.weight"].shape[-1]
+    unused = ["multiple_of", "ffn_dim_multiplier"]
+    for k in unused:
+        if k in config:
+            config.pop(k)
+    model = Llama(ModelArgs(**config))
     model.update(tree_unflatten(list(weights.items())))
+    mx.eval(model.parameters())
     return model


 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Llama inference script")
-    parser.add_argument("model", help="The model file containing MLX weights")
+    parser.add_argument(
+        "model", help="Path to the model directory containing the MLX weights"
+    )
     parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
     parser.add_argument("prompt", help="The message to be processed by the model")
     parser.add_argument(
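The chain of `if` checks in `load_model` exists because Llama 1 style `params.json` files omit fields that `ModelArgs` requires. A compact hedged illustration of the same fill-in; the numeric values are assumptions for a 7B-style config, not taken from the diff:

```python
# Hedged illustration of load_model's config handling; values are assumptions.
config = {
    "dim": 4096, "n_layers": 32, "n_heads": 32,
    "norm_eps": 1e-05, "vocab_size": -1, "multiple_of": 256,
}

n_heads = config["n_heads"]
config.setdefault("n_kv_heads", n_heads)                 # absent for MHA checkpoints
config.setdefault("head_dim", config["dim"] // n_heads)  # 128 here
config.setdefault("hidden_dim", 11008)                   # really read from w1's shape
if config.get("vocab_size", -1) < 0:
    config["vocab_size"] = 32000                         # really read from output.weight
for k in ("multiple_of", "ffn_dim_multiplier"):
    config.pop(k, None)                                  # keys ModelArgs does not accept

print(sorted(config))  # exactly the ModelArgs fields
```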

View File

@@ -1,2 +1,3 @@
+mlx
 sentencepiece
 torch

View File

@@ -14,7 +14,7 @@ For example with Homebrew:
 brew install git-lfs
 ```

-Download the models from HuggingFace:
+Download the models from Hugging Face:

 ```
 git clone https://huggingface.co/someone13574/mixtral-8x7b-32kseqlen
@@ -46,7 +46,7 @@ rm mixtral-8x7b-32kseqlen/*.pth*
 As easy as:

 ```
-python mixtral.py --model_path mixtral mixtral-8x7b-32kseqlen/
+python mixtral.py --model_path mixtral-8x7b-32kseqlen/
 ```

 [^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) for more details.

View File

@@ -1,9 +1,9 @@
 Stable Diffusion
 ================

-Stable Diffusion in MLX. The implementation was ported from Huggingface's
+Stable Diffusion in MLX. The implementation was ported from Hugging Face's
 [diffusers](https://huggingface.co/docs/diffusers/index) and we are fetching
-and using the weights available on the Huggingface Hub by Stability AI at
+and using the weights available on the Hugging Face Hub by Stability AI at
 [stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1).

 ![out](generated-mlx.png)

View File

@@ -169,7 +169,7 @@ def _check_key(key: str, part: str):

 def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
-    """Load the stable diffusion UNet from Huggingface Hub."""
+    """Load the stable diffusion UNet from Hugging Face Hub."""
     _check_key(key, "load_unet")

     # Download the config and create the model
@@ -199,7 +199,7 @@ def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):

 def load_text_encoder(key: str = _DEFAULT_MODEL, float16: bool = False):
-    """Load the stable diffusion text encoder from Huggingface Hub."""
+    """Load the stable diffusion text encoder from Hugging Face Hub."""
     _check_key(key, "load_text_encoder")

     # Download the config and create the model
@@ -226,7 +226,7 @@ def load_text_encoder(key: str = _DEFAULT_MODEL, float16: bool = False):

 def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False):
-    """Load the stable diffusion autoencoder from Huggingface Hub."""
+    """Load the stable diffusion autoencoder from Hugging Face Hub."""
     _check_key(key, "load_autoencoder")

     # Download the config and create the model
@@ -255,7 +255,7 @@ def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False):

 def load_diffusion_config(key: str = _DEFAULT_MODEL):
-    """Load the stable diffusion config from Huggingface Hub."""
+    """Load the stable diffusion config from Hugging Face Hub."""
     _check_key(key, "load_diffusion_config")

     diffusion_config = _MODELS[key]["diffusion_config"]

View File

@@ -81,7 +81,7 @@ class Tokenizer:
         if isinstance(text, list):
             return [self.tokenize(t, prepend_bos, append_eos) for t in text]

-        # Lower case cleanup and split according to self.pat. Huggingface does
+        # Lower case cleanup and split according to self.pat. Hugging Face does
         # a much more thorough job here but this should suffice for 95% of
         # cases.
         clean_text = regex.sub(r"\s+", " ", text.lower())
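As a tiny illustration of what that cleanup line does (a sketch; the sample string is arbitrary):

```python
# Collapse runs of whitespace and lowercase, as the tokenizer does above.
import regex

text = "A Photo\tof  an\nASTRONAUT riding a horse"
print(regex.sub(r"\s+", " ", text.lower()))
# -> "a photo of an astronaut riding a horse"
```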