diff --git a/llama/README.md b/llama/README.md
index 220d1b16..39c0267c 100644
--- a/llama/README.md
+++ b/llama/README.md
@@ -3,8 +3,8 @@
 An example of generating text with Llama (1 or 2) using MLX.
 
 Llama is a set of open source language models from Meta AI Research[^1][^2]
-ranging from 7B to 70B parameters. This example also supports Llama Chat and
-Code Llama.
+ranging from 7B to 70B parameters. This example also supports Meta's Llama Chat
+and Code Llama models, as well as the 1.1B TinyLlama models from SUTD.[^3]
 
 ### Setup
 
@@ -25,10 +25,19 @@ Alternatively, you can also download a select converted checkpoints from the
 [mlx-llama](https://huggingface.co/mlx-llama) community organisation on Hugging
 Face and skip the conversion step.
 
+You can download the TinyLlama models directly from [Hugging
+Face](https://huggingface.co/TinyLlama).
+
 Convert the weights with:
 
 ```
-python convert.py --model_path <path_to_model>
+python convert.py --model-path <path_to_model>
+```
+
+For TinyLlama, use:
+
+```
+python convert.py --model-path <path_to_model> --model-name tiny_llama
 ```
 
 The conversion script will save the converted weights in the same location.
@@ -39,10 +48,11 @@ Once you've converted the weights to MLX format, you can interact with the
 LlaMA model:
 
 ```
-python llama.py "hello"
+python llama.py --prompt "hello"
 ```
 
 Run `python llama.py --help` for more details.
 
 [^1]: For Llama v1 refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
 [^2]: For Llama v2 refer to the [blob post](https://ai.meta.com/llama/)
+[^3]: For TinyLlama refer to the [GitHub repository](https://github.com/jzhang38/TinyLlama?tab=readme-ov-file)
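
(Not part of the patch.) As a quick sanity check after running either conversion command above, something like the sketch below can be used to inspect what `convert.py` wrote: `weights.npz` in all cases, plus `params.json` for the TinyLlama conversion. The model path is a placeholder.

```python
# Sketch only: inspect the files convert.py writes (weights.npz always,
# params.json only for --model-name tiny_llama). The path is a placeholder.
import json
from pathlib import Path

import numpy as np

model_path = Path("path/to/model")  # wherever the converted checkpoint lives

# weights.npz maps names like "layers.0.attention.wq.weight" to arrays
weights = np.load(model_path / "weights.npz")
print(f"{len(weights.files)} tensors in weights.npz")
for name in sorted(weights.files)[:5]:
    print(name, weights[name].shape, weights[name].dtype)

params_file = model_path / "params.json"  # written only for tiny_llama conversions
if params_file.exists():
    print(json.dumps(json.loads(params_file.read_text()), indent=2))
```
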
diff --git a/llama/convert.py b/llama/convert.py
index 89ce8a36..957f3d22 100644
--- a/llama/convert.py
+++ b/llama/convert.py
@@ -3,45 +3,35 @@
 import argparse
 import collections
 import glob
-from pathlib import Path
-
+import json
 import numpy as np
+from pathlib import Path
 import torch
 
-SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
-SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
-SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)
 
+def llama(model_path):
+    SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
+    SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
+    SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)
 
-def shard_key(k):
-    keys = k.split(".")
-    if len(keys) < 2:
-        return None
-    return keys[-2]
+    def shard_key(k):
+        keys = k.split(".")
+        if len(keys) < 2:
+            return None
+        return keys[-2]
 
+    def unshard(k, v):
+        wn = shard_key(k)
+        if wn not in SHARD_WEIGHTS:
+            return v
+        elif wn in SHARD_FIRST:
+            axis = 0
+        elif wn in SHARD_SECOND:
+            axis = 1
+        else:
+            raise ValueError("Invalid weight name")
+        return np.concatenate(v, axis=axis)
 
-def unshard(k, v):
-    wn = shard_key(k)
-    if wn not in SHARD_WEIGHTS:
-        return v
-    elif wn in SHARD_FIRST:
-        axis = 0
-    elif wn in SHARD_SECOND:
-        axis = 1
-    else:
-        raise ValueError("Invalid weight name")
-    return np.concatenate(v, axis=axis)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
-    parser.add_argument(
-        "--model_path",
-        help="Path to the Torch model. The MLX weights will also be saved there.",
-    )
-    args = parser.parse_args()
-
-    model_path = Path(args.model_path)
     torch_files = glob.glob(str(model_path / "consolidated.*.pth"))
     weights = collections.defaultdict(list)
     for wf in torch_files:
@@ -53,7 +43,96 @@ if __name__ == "__main__":
             else:
                 weights[k] = v
 
-    out_file = str(model_path / "weights.npz")
     for k, v in weights.items():
         weights[k] = unshard(k, v)
-    np.savez(out_file, **weights)
+    return weights, None
+
+
+def tiny_llama(model_path):
+    try:
+        import transformers
+    except ImportError:
+        print("The transformers package must be installed for this model conversion:")
+        print("pip install transformers")
+        import sys
+
+        sys.exit(1)
+
+    model = transformers.AutoModelForCausalLM.from_pretrained(
+        str(model_path)
+    ).state_dict()
+    config = transformers.AutoConfig.from_pretrained(model_path)
+
+    # things to change
+    # 1. there's no "model." in the weight names
+    model = {k.replace("model.", ""): v for k, v in model.items()}
+
+    # 2. mlp is called feed_forward
+    model = {k.replace("mlp", "feed_forward"): v for k, v in model.items()}
+
+    # 3. up_proj, down_proj, gate_proj
+    model = {k.replace("down_proj", "w2"): v for k, v in model.items()}
+    model = {k.replace("up_proj", "w3"): v for k, v in model.items()}
+    model = {k.replace("gate_proj", "w1"): v for k, v in model.items()}
+
+    # 4. layernorms
+    model = {
+        k.replace("input_layernorm", "attention_norm"): v for k, v in model.items()
+    }
+    model = {
+        k.replace("post_attention_layernorm", "ffn_norm"): v for k, v in model.items()
+    }
+
+    # 5. lm head
+    model = {k.replace("lm_head", "output"): v for k, v in model.items()}
+
+    # 6. token emb
+    model = {k.replace("embed_tokens", "tok_embeddings"): v for k, v in model.items()}
+
+    # 7. attention
+    model = {k.replace("self_attn", "attention"): v for k, v in model.items()}
+    model = {k.replace("q_proj", "wq"): v for k, v in model.items()}
+    model = {k.replace("k_proj", "wk"): v for k, v in model.items()}
+    model = {k.replace("v_proj", "wv"): v for k, v in model.items()}
+    model = {k.replace("o_proj", "wo"): v for k, v in model.items()}
+
+    params = {}
+    params["dim"] = config.hidden_size
+    params["hidden_dim"] = config.intermediate_size
+    params["n_heads"] = config.num_attention_heads
+    if hasattr(config, "num_key_value_heads"):
+        params["n_kv_heads"] = config.num_key_value_heads
+    params["n_layers"] = config.num_hidden_layers
+    params["vocab_size"] = config.vocab_size
+    params["norm_eps"] = config.rms_norm_eps
+    params["rope_traditional"] = False
+    weights = {k: v.to(torch.float16).numpy() for k, v in model.items()}
+
+    return weights, params
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
+    parser.add_argument(
+        "--model-path",
+        help="Path to the model. The MLX weights will also be saved there.",
+    )
+    parser.add_argument(
+        "--model-name",
+        help=(
+            "Name of the model to convert. Use 'llama' for models in the "
+            "Llama family distributed by Meta including Llama 1, Llama 2, "
+            "Code Llama, and Llama Chat."
+        ),
+        choices=["tiny_llama", "llama"],
+        default="llama",
+    )
+
+    args = parser.parse_args()
+
+    model_path = Path(args.model_path)
+    weights, params = globals()[args.model_name](model_path)
+    np.savez(str(model_path / "weights.npz"), **weights)
+    if params is not None:
+        with open(model_path / "params.json", "w") as fid:
+            json.dump(params, fid, indent=4)
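
(Not part of the patch.) The core of `tiny_llama()` above is an ordered chain of substring renames from Hugging Face parameter names to the names this example's `llama.py` expects. The sketch below replays that chain on a few sample keys; the keys follow the usual Hugging Face Llama naming and are only illustrative.

```python
# Replays the rename chain from tiny_llama() on illustrative Hugging Face keys.
HF_TO_MLX = [
    ("model.", ""),
    ("mlp", "feed_forward"),
    ("down_proj", "w2"),
    ("up_proj", "w3"),
    ("gate_proj", "w1"),
    ("input_layernorm", "attention_norm"),
    ("post_attention_layernorm", "ffn_norm"),
    ("lm_head", "output"),
    ("embed_tokens", "tok_embeddings"),
    ("self_attn", "attention"),
    ("q_proj", "wq"),
    ("k_proj", "wk"),
    ("v_proj", "wv"),
    ("o_proj", "wo"),
]


def to_mlx_name(name: str) -> str:
    # Apply the renames in the same order as the conversion script.
    for old, new in HF_TO_MLX:
        name = name.replace(old, new)
    return name


for key in [
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.mlp.gate_proj.weight",
    "model.layers.0.post_attention_layernorm.weight",
    "lm_head.weight",
]:
    print(f"{key:45s} -> {to_mlx_name(key)}")
```

Note that the order matters: `"model."` is stripped first, so `model.layers.0...` becomes `layers.0...` before the per-module renames are applied.
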
diff --git a/llama/llama.py b/llama/llama.py
index 5f169de4..af98685d 100644
--- a/llama/llama.py
+++ b/llama/llama.py
@@ -24,6 +24,7 @@ class ModelArgs:
     norm_eps: float
     vocab_size: int
     rope_theta: float
+    rope_traditional: bool = True
 
 
 class RMSNorm(nn.Module):
@@ -77,7 +78,9 @@ class Attention(nn.Module):
         self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
-        self.rope = RoPE(args.head_dim, traditional=True, base=args.rope_theta)
+        self.rope = RoPE(
+            args.head_dim, traditional=args.rope_traditional, base=args.rope_theta
+        )
 
     def __call__(
         self,
@@ -234,7 +237,7 @@ def generate(args):
 
     input("Press enter to start generation")
     print("------")
-
+    print(args.prompt)
     x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)])
     skip = 0
     prompt_processing = None
@@ -248,7 +251,7 @@
             mx.eval(token)
             prompt_processing = toc("Prompt processing", start)
 
-        if len(tokens) >= args.num_tokens:
+        if len(tokens) >= args.max_tokens:
             break
 
         elif (len(tokens) % args.write_every) == 0:
@@ -261,8 +264,7 @@
     mx.eval(tokens)
     full_gen = toc("Full generation", start)
     s = tokenizer.decode([t.item() for t in tokens])
-    print(s[skip:], end="", flush=True)
-    print()
+    print(s[skip:], flush=True)
     print("------")
     print(prompt_processing)
     print(full_gen)
@@ -354,14 +356,18 @@ if __name__ == "__main__":
         "model", help="Path to the model directory containing the MLX weights"
     )
     parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
-    parser.add_argument("prompt", help="The message to be processed by the model")
+    parser.add_argument(
+        "--prompt",
+        help="The message to be processed by the model",
+        default="In the beginning the Universe was created.",
+    )
     parser.add_argument(
         "--few-shot",
         action="store_true",
         help="Read a few shot prompt from a file (as in `sample_prompt.txt`).",
     )
     parser.add_argument(
-        "--num-tokens", "-n", type=int, default=100, help="How many tokens to generate"
+        "--max-tokens", "-m", type=int, default=100, help="How many tokens to generate"
     )
     parser.add_argument(
         "--write-every", type=int, default=1, help="After how many tokens to detokenize"
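
(Not part of the patch.) On the `rope_traditional` addition: the field defaults to `True` so configurations that never mention the key, such as the Meta checkpoints for which `llama()` in `convert.py` returns `params=None`, keep traditional RoPE, while the TinyLlama conversion writes `"rope_traditional": false` into `params.json` and overrides it. The dataclass below is a minimal stand-in to show that fallback behaviour, not the real `ModelArgs`.

```python
# Illustrative only: a defaulted dataclass field lets older configs omit the key.
import json
from dataclasses import dataclass


@dataclass
class ExampleArgs:  # stand-in, not llama.py's ModelArgs
    dim: int
    rope_traditional: bool = True  # falls back to traditional RoPE when the key is absent


meta_style = {"dim": 4096}  # no rope_traditional key in the config
tiny_llama_style = json.loads('{"dim": 2048, "rope_traditional": false}')

print(ExampleArgs(**meta_style))        # rope_traditional=True
print(ExampleArgs(**tiny_llama_style))  # rope_traditional=False
```
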