From f0c57c1361b4084976f4e39df3a6bc82d44ef40d Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 12 Dec 2023 12:48:15 -0800
Subject: [PATCH] llama v2 with sharded weights

---
 llama/README.md        |  25 +++--
 llama/convert.py       |  78 +++++++++-------
 llama/llama.py         | 206 ++++++++++++++++++++++++++---------------
 llama/requirements.txt |   1 +
 mixtral/README.md      |   2 +-
 5 files changed, 189 insertions(+), 123 deletions(-)

diff --git a/llama/README.md b/llama/README.md
index b9f487dd..da4e85f3 100644
--- a/llama/README.md
+++ b/llama/README.md
@@ -1,8 +1,9 @@
-# LLaMA
+# Llama
 
-An example of generating text with LLaMA using MLX.
+An example of generating text with Llama (1 or 2) using MLX.
 
-LLaMA is a set of open source language models from Meta AI Research[^1] ranging from 7B to 65B parameters.
+Llama is a set of open source language models from Meta AI Research[^1][^2]
+ranging from 7B to 70B parameters.
 
 ### Setup
 
@@ -14,27 +15,31 @@ pip install -r requirements.txt
 Next, download and convert the model.
 
 If you do not have access to the model weights you will need to [request
-access](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
+access](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)
 from Meta.
 
-Alternatively, you can also download a select converted checkpoints from the [mlx-llama](https://huggingface.co/mlx-llama) community organisation on Hugging Face and skip the conversion step.
+Alternatively, you can download a selection of converted checkpoints from the
+[mlx-llama](https://huggingface.co/mlx-llama) community organisation on Hugging
+Face and skip the conversion step.
 
 Convert the weights with:
 
 ```
-python convert.py
+python convert.py --model_path <path_to_torch_model>
 ```
 
+The conversion script will save the converted weights in the same location.
+
 ### Run
 
 Once you've converted the weights to MLX format, you can interact with the
-LLaMA model:
+Llama model:
 
 ```
-python llama.py "hello"
+python llama.py <path_to_model> <path_to_tokenizer> "hello"
 ```
 
 Run `python llama.py --help` for more details.
 
-[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
+[^1]: For Llama v1 refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
+[^2]: For Llama v2 refer to the [blog post](https://ai.meta.com/llama/) for more details.
diff --git a/llama/convert.py b/llama/convert.py
index 69168493..89ce8a36 100644
--- a/llama/convert.py
+++ b/llama/convert.py
@@ -1,53 +1,59 @@
 # Copyright © 2023 Apple Inc.
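+#
+# Combine a model's sharded checkpoint files (consolidated.*.pth) into a
+# single MLX weights.npz. Weights sharded along their first axis (wq, wk, wv,
+# w1, w3, output) are concatenated on axis 0; weights sharded along their
+# second axis (tok_embeddings, wo, w2) are concatenated on axis 1.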
 import argparse
-from itertools import starmap
+import collections
+import glob
+from pathlib import Path
 
 import numpy as np
 import torch
 
+SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
+SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
+SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)
 
-def map_torch_to_mlx(key, value):
-    if "tok_embedding" in key:
-        key = "embedding.weight"
 
-    elif "norm" in key:
-        key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
+def shard_key(k):
+    keys = k.split(".")
+    if len(keys) < 2:
+        return None
+    return keys[-2]
 
-    elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
-        key = key.replace("wq", "query_proj")
-        key = key.replace("wk", "key_proj")
-        key = key.replace("wv", "value_proj")
-        key = key.replace("wo", "out_proj")
 
-    elif "w1" in key or "w2" in key or "w3" in key:
-        # The FFN is a separate submodule in PyTorch
-        key = key.replace("feed_forward.w1", "linear1")
-        key = key.replace("feed_forward.w3", "linear2")
-        key = key.replace("feed_forward.w2", "linear3")
-
-    elif "output" in key:
-        key = key.replace("output", "out_proj")
-
-    elif "rope" in key:
-        return None, None
-
-    return (
-        key,
-        value.numpy()
-        if value.dtype != torch.bfloat16
-        else value.to(torch.float32).numpy(),
-    )
+def unshard(k, v):
+    wn = shard_key(k)
+    if wn not in SHARD_WEIGHTS:
+        return v
+    elif wn in SHARD_FIRST:
+        axis = 0
+    elif wn in SHARD_SECOND:
+        axis = 1
+    else:
+        raise ValueError("Invalid weight name")
+    return np.concatenate(v, axis=axis)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
-    parser.add_argument("torch_weights")
-    parser.add_argument("output_file")
+    parser.add_argument(
+        "--model_path",
+        help="Path to the Torch model. The MLX weights will also be saved there.",
+    )
     args = parser.parse_args()
 
-    state = torch.load(args.torch_weights, map_location=torch.device('cpu'))
-    np.savez(
-        args.output_file,
-        **{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
-    )
+    model_path = Path(args.model_path)
+    # Sort the shard files so that unshard concatenates them in checkpoint order.
+    torch_files = sorted(glob.glob(str(model_path / "consolidated.*.pth")))
+    weights = collections.defaultdict(list)
+    for wf in torch_files:
+        state = torch.load(wf, map_location=torch.device("cpu"))
+        for k, v in state.items():
+            v = v.to(torch.float16).numpy()
+            if shard_key(k) in SHARD_WEIGHTS:
+                weights[k].append(v)
+            else:
+                weights[k] = v
+
+    out_file = str(model_path / "weights.npz")
+    for k, v in weights.items():
+        weights[k] = unshard(k, v)
+    np.savez(out_file, **weights)
diff --git a/llama/llama.py b/llama/llama.py
index c18728ff..db9c8db3 100644
--- a/llama/llama.py
+++ b/llama/llama.py
@@ -1,8 +1,10 @@
 # Copyright © 2023 Apple Inc.
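+#
+# Llama (v1 and v2) in MLX: RMSNorm, rotary position embeddings, attention
+# with optional grouped-query heads (n_kv_heads < n_heads), a SwiGLU
+# feed-forward block, and a per-layer (key, value) cache for incremental
+# decoding.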
 import argparse
-import math
-import numpy as np
+from dataclasses import dataclass
+import json
+from pathlib import Path
+from typing import Optional, Tuple, List
 
 from sentencepiece import SentencePieceProcessor
 import time
 
@@ -11,33 +13,71 @@
 import mlx.nn as nn
 from mlx.utils import tree_unflatten
 
 
-class LlamaAttention(nn.Module):
-    def __init__(self, dims: int, num_heads: int):
+@dataclass
+class ModelArgs:
+    dim: int
+    n_layers: int
+    head_dim: int
+    hidden_dim: int
+    n_heads: int
+    n_kv_heads: int
+    norm_eps: float
+    vocab_size: int
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dims: int, eps: float = 1e-5):
         super().__init__()
+        self.weight = mx.ones((dims,))
+        self.eps = eps
 
-        self.num_heads = num_heads
+    def _norm(self, x):
+        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
 
-        self.rope = nn.RoPE(dims // num_heads, traditional=True)
-        self.query_proj = nn.Linear(dims, dims, bias=False)
-        self.key_proj = nn.Linear(dims, dims, bias=False)
-        self.value_proj = nn.Linear(dims, dims, bias=False)
-        self.out_proj = nn.Linear(dims, dims, bias=False)
+    def __call__(self, x):
+        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
+        return self.weight * output
 
-    def __call__(self, queries, keys, values, mask=None, cache=None):
-        queries = self.query_proj(queries)
-        keys = self.key_proj(keys)
-        values = self.value_proj(values)
 
-        # Extract some shapes
-        num_heads = self.num_heads
-        B, L, D = queries.shape
+class Attention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+
+        self.n_heads: int = args.n_heads
+        self.n_kv_heads: int = args.n_kv_heads
+
+        self.repeats = self.n_heads // self.n_kv_heads
+
+        self.scale = self.args.head_dim**-0.5
+
+        self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
+        self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
+        self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
+        self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
+        self.rope = nn.RoPE(args.head_dim, traditional=True)
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
+        B, L, D = x.shape
+
+        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
 
         # Prepare the queries, keys and values for the attention computation
-        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
-        values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+        # Tile the shared key/value heads so grouped-query attention can reuse
+        # the regular multi-head attention computation below.
+        def repeat(a):
+            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
+            return a.reshape([B, self.n_heads, L, -1])
+
+        keys, values = map(repeat, (keys, values))
 
-        # Add RoPE to the queries and keys and combine them with the cache
         if cache is not None:
             key_cache, value_cache = cache
             queries = self.rope(queries, offset=key_cache.shape[2])
@@ -48,86 +88,87 @@ class LlamaAttention(nn.Module):
             queries = self.rope(queries)
             keys = self.rope(keys)
 
-        # Finally perform the attention computation
-        scale = math.sqrt(1 / queries.shape[-1])
-        scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
+        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
         if mask is not None:
-            scores = scores + mask
-        scores = mx.softmax(scores, axis=-1)
-        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-
-        # Note that we return the keys and values to possibly be used as a cache
-        return self.out_proj(values_hat), (keys, values)
+            scores += mask
+        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
+        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.wo(output), (keys, values)
 
 
-class LlamaEncoderLayer(nn.Module):
-    def __init__(self, dims: int, mlp_dims: int, num_heads: int):
+class FeedForward(nn.Module):
+    def __init__(self, args: ModelArgs):
         super().__init__()
 
-        self.attention = LlamaAttention(dims, num_heads)
+        self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
+        self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
+        self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
 
-        self.norm1 = nn.RMSNorm(dims)
-        self.norm2 = nn.RMSNorm(dims)
+    def __call__(self, x) -> mx.array:
+        return self.w2(nn.silu(self.w1(x)) * self.w3(x))
 
-        self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
-        self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
-        self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
 
-    def __call__(self, x, mask=None, cache=None):
-        y = self.norm1(x)
-        y, cache = self.attention(y, y, y, mask, cache)
-        x = x + y
+class TransformerBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.n_heads = args.n_heads
+        self.dim = args.dim
+        self.attention = Attention(args)
+        self.feed_forward = FeedForward(args=args)
+        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.args = args
 
-        y = self.norm2(x)
-        a = self.linear1(y)
-        b = self.linear2(y)
-        y = a * mx.sigmoid(a) * b
-        y = self.linear3(y)
-        x = x + y
-
-        return x, cache
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
+        r, cache = self.attention(self.attention_norm(x), mask, cache)
+        h = x + r
+        r = self.feed_forward(self.ffn_norm(h))
+        out = h + r
+        return out, cache
 
 
 class Llama(nn.Module):
-    def __init__(
-        self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int
-    ):
+    def __init__(self, args: ModelArgs):
         super().__init__()
-
-        self.embedding = nn.Embedding(vocab_size, dims)
-        self.layers = [
-            LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)
-        ]
-        self.norm = nn.RMSNorm(dims)
-        self.out_proj = nn.Linear(dims, vocab_size, bias=False)
+        self.args = args
+        self.vocab_size = args.vocab_size
+        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
+        self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
+        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
 
     def __call__(self, x):
         mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
-        mask = mask.astype(self.embedding.weight.dtype)
+        mask = mask.astype(self.tok_embeddings.weight.dtype)
 
-        x = self.embedding(x)
+        x = self.tok_embeddings(x)
         for l in self.layers:
             x, _ = l(x, mask)
         x = self.norm(x)
-        return self.out_proj(x)
+        return self.output(x)
 
     def generate(self, x, temp=1.0):
         cache = []
 
         # Make an additive causal mask. We will need that to process the prompt.
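+        # The mask has large negative values above the diagonal so that, after
+        # the softmax, each position attends only to itself and to earlier
+        # positions.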
         mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
-        mask = mask.astype(self.embedding.weight.dtype)
+        mask = mask.astype(self.tok_embeddings.weight.dtype)
 
         # First we process the prompt x the same way as in __call__ but
         # save the caches in cache
-        x = self.embedding(x)
+        x = self.tok_embeddings(x)
         for l in self.layers:
             x, c = l(x, mask=mask)
             # We store the per layer cache in a simple python list
             cache.append(c)
         x = self.norm(x)
         # We only care about the last logits that generate the next token
-        y = self.out_proj(x[:, -1])
+        y = self.output(x[:, -1])
         y = mx.random.categorical(y * (1 / temp))
 
         # y now has size [1]
@@ -145,14 +186,14 @@ class Llama(nn.Module):
             # dimension of 1
             x = y[:, None]
 
-            x = self.embedding(x)
+            x = self.tok_embeddings(x)
             for i in range(len(cache)):
                 # We are overwriting the arrays in the cache list. When the
                 # computation happens, MLX will discard the old cache the
                 # moment it is not needed anymore.
                 x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
             x = self.norm(x)
-            y = self.out_proj(x[:, -1])
+            y = self.output(x[:, -1])
             y = mx.random.categorical(y * (1 / temp))
 
             yield y
 
@@ -261,20 +302,33 @@ def few_shot_generate(args):
 
 def load_model(model_path):
-    weights = mx.load(model_path)
-    mlp_dims, dims = weights["layers.0.linear1.weight"].shape
-    num_heads = dims // 128
-    num_layers = max(int(l.split(".")[1]) for l in weights.keys() if "layers" in l) + 1
-    vocab_size = weights["out_proj.weight"].shape[-1]
-    model = Llama(num_layers, vocab_size, dims, mlp_dims, num_heads)
+    model_path = Path(model_path)
+    weights = mx.load(str(model_path / "weights.npz"))
+    with open(model_path / "params.json", "r") as f:
+        config = json.loads(f.read())
+    n_heads = config["n_heads"]
+    if "n_kv_heads" not in config:
+        config["n_kv_heads"] = n_heads
+    if "head_dim" not in config:
+        config["head_dim"] = config["dim"] // n_heads
+    if "hidden_dim" not in config:
+        config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
+    if config.get("vocab_size", -1) < 0:
+        config["vocab_size"] = weights["output.weight"].shape[-1]
+    unused = ["multiple_of", "ffn_dim_multiplier"]
+    for k in unused:
+        if k in config:
+            config.pop(k)
+    model = Llama(ModelArgs(**config))
     model.update(tree_unflatten(list(weights.items())))
-    mx.eval(model.parameters())
     return model
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Llama inference script")
-    parser.add_argument("model", help="The model file containing MLX weights")
+    parser.add_argument(
+        "model", help="Path to the model directory containing the MLX weights"
+    )
     parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
     parser.add_argument("prompt", help="The message to be processed by the model")
     parser.add_argument(
diff --git a/llama/requirements.txt b/llama/requirements.txt
index c036fa59..7111f1d4 100644
--- a/llama/requirements.txt
+++ b/llama/requirements.txt
@@ -1,2 +1,3 @@
+mlx
 sentencepiece
 torch
diff --git a/mixtral/README.md b/mixtral/README.md
index 23de1430..494e8107 100644
--- a/mixtral/README.md
+++ b/mixtral/README.md
@@ -46,7 +46,7 @@ rm mixtral-8x7b-32kseqlen/*.pth*
 As easy as:
 
 ```
-python mixtral.py --model_path mixtral mixtral-8x7b-32kseqlen/
+python mixtral.py --model_path mixtral-8x7b-32kseqlen/
 ```
 
 [^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) for more details.
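For reference, the pieces above compose as follows — a minimal usage sketch, not part of the patch, assuming `load_model` is importable from `llama/llama.py` and that a converted directory `Llama-2-7b/` (an illustrative path) holds `weights.npz`, `params.json`, and the SentencePiece `tokenizer.model`:

```python
import mlx.core as mx
from sentencepiece import SentencePieceProcessor

from llama import load_model  # load_model as defined in llama/llama.py

# Directory produced by convert.py; the path is illustrative.
model = load_model("Llama-2-7b")
tokenizer = SentencePieceProcessor(model_file="Llama-2-7b/tokenizer.model")

# Encode the prompt with a leading BOS token and a batch dimension of 1.
prompt = mx.array([[tokenizer.bos_id()] + tokenizer.encode("hello")])

# generate lazily yields one token at a time; item() forces evaluation.
tokens = []
for token in model.generate(prompt, temp=0.8):
    tokens.append(token.item())
    if len(tokens) >= 64:
        break
print(tokenizer.decode(tokens))
```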