From c8abd7906dd892005b9911b5eaff7f29c02dbf97 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Wed, 20 Dec 2023 21:25:25 -0800
Subject: [PATCH] llama / mistral conversion in good shape

---
 llms/llama/README.md    | 14 ++++++-----
 llms/llama/convert.py   | 56 ++++++++++++++++++++++++++++++-----------
 llms/llama/llama.py     | 11 ++++----
 llms/mistral/convert.py |  7 ++----
 4 files changed, 58 insertions(+), 30 deletions(-)

diff --git a/llms/llama/README.md b/llms/llama/README.md
index 1189d928..3d9a97f1 100644
--- a/llms/llama/README.md
+++ b/llms/llama/README.md
@@ -30,30 +30,32 @@ Face](https://huggingface.co/TinyLlama).
 Convert the weights with:
 
 ```
-python convert.py --model-path <path_to_model>
+python convert.py --torch-path <path_to_torch_model>
 ```
 
 To generate a 4-bit quantized model use the `-q` flag:
 
 ```
-python convert.py --model-path <path_to_model> -q
+python convert.py --torch-path <path_to_torch_model> -q
 ```
 
 For TinyLlama use
 
 ```
-python convert.py --model-path <path_to_model> --model-name tiny_llama
+python convert.py --torch-path <path_to_torch_model> --model-name tiny_llama
 ```
 
-The conversion script will save the converted weights in the same location.
+By default, the conversion script will make the directory `mlx_model` and save
+the converted `weights.npz`, `tokenizer.model`, and `config.json` there.
+
 
 ### Run
 
 Once you've converted the weights to MLX format, you can interact with the
-LlaMA model:
+Llama model:
 
 ```
-python llama.py <path_to_model> <path_to_tokenizer> --prompt "hello"
+python llama.py --prompt "hello"
 ```
 
 Run `python llama.py --help` for more details.
diff --git a/llms/llama/convert.py b/llms/llama/convert.py
index edcfe668..618c3070 100644
--- a/llms/llama/convert.py
+++ b/llms/llama/convert.py
@@ -5,6 +5,7 @@ import collections
 import copy
 import glob
 import json
+import shutil
 from pathlib import Path
 
 import mlx.core as mx
@@ -62,9 +63,7 @@ def tiny_llama(model_path):
     except ImportError as e:
         print("The transformers package must be installed for this model conversion:")
         print("pip install transformers")
-        import sys
-
-        sys.exit(0)
+        exit(0)
 
     model = transformers.AutoModelForCausalLM.from_pretrained(
         str(model_path)
@@ -119,7 +118,7 @@ def tiny_llama(model_path):
     return weights, params
 
 
-def quantize(weights, config):
+def quantize(weights, config, args):
     quantized_config = copy.deepcopy(config)
 
     # Load the model:
@@ -129,10 +128,13 @@ def quantize(weights, config):
     model.update(tree_unflatten(list(weights.items())))
 
     # Quantize the model:
-    nn.QuantizedLinear.quantize_module(model)
+    nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
 
     # Update the config:
-    quantized_config["quantization"] = {"group_size": 64, "bits": 4}
+    quantized_config["quantization"] = {
+        "group_size": args.q_group_size,
+        "bits": args.q_bits,
+    }
     quantized_weights = dict(tree_flatten(model.parameters()))
 
     return quantized_weights, quantized_config
@@ -141,8 +143,15 @@
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
     parser.add_argument(
-        "--model-path",
-        help="Path to the model. The MLX weights will also be saved there.",
+        "--torch-path",
+        type=str,
+        help="Path to the PyTorch model.",
+    )
+    parser.add_argument(
+        "--mlx-path",
+        type=str,
+        default="mlx_model",
+        help="Path to save the MLX model.",
     )
     parser.add_argument(
         "--model-name",
@@ -157,21 +166,40 @@
     parser.add_argument(
         "-q",
         "--quantize",
-        help="Generate a 4-bit quantized model.",
+        help="Generate a quantized model.",
         action="store_true",
     )
+    parser.add_argument(
+        "--q_group_size",
+        help="Group size for quantization.",
+        type=int,
+        default=64,
+    )
+    parser.add_argument(
+        "--q_bits",
+        help="Bits per weight for quantization.",
+        type=int,
+        default=4,
+    )
 
     args = parser.parse_args()
 
-    model_path = Path(args.model_path)
+    torch_path = Path(args.torch_path)
+    mlx_path = Path(args.mlx_path)
+    mlx_path.mkdir(parents=True, exist_ok=True)
+
     print("[INFO] Loading")
-    weights, params = globals()[args.model_name](model_path)
+    weights, params = globals()[args.model_name](torch_path)
     params["model_type"] = "llama"
     if args.quantize:
         print("[INFO] Quantizing")
-        weights, params = quantize(weights, params)
+        weights, params = quantize(weights, params, args)
 
     print("[INFO] Saving")
-    np.savez(str(model_path / "weights.npz"), **weights)
-    with open(model_path / "config.json", "w") as fid:
+    shutil.copyfile(
+        str(torch_path / "tokenizer.model"),
+        str(mlx_path / "tokenizer.model"),
+    )
+    np.savez(str(mlx_path / "weights.npz"), **weights)
+    with open(mlx_path / "config.json", "w") as fid:
         json.dump(params, fid, indent=4)
diff --git a/llms/llama/llama.py b/llms/llama/llama.py
index 324c210a..6a7352f3 100644
--- a/llms/llama/llama.py
+++ b/llms/llama/llama.py
@@ -360,15 +360,17 @@ def load_model(model_path):
     if quantization is not None:
         nn.QuantizedLinear.quantize_module(model, **quantization)
     model.update(tree_unflatten(list(weights.items())))
-    return model
+    tokenizer = SentencePieceProcessor(model_file=str(model_path / "tokenizer.model"))
+    return model, tokenizer
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Llama inference script")
     parser.add_argument(
-        "model", help="Path to the model directory containing the MLX weights"
+        "--model-path",
+        help="Path to the model directory containing the MLX weights",
+        default="mlx_model",
     )
-    parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
     parser.add_argument(
         "--prompt",
         help="The message to be processed by the model. Ignored when --few-shot is provided.",
@@ -393,9 +395,8 @@
 
     mx.random.seed(args.seed)
 
-    tokenizer = SentencePieceProcessor(model_file=args.tokenizer)
     print("[INFO] Loading model from disk.")
-    model = load_model(args.model)
+    model, tokenizer = load_model(args.model_path)
     if args.few_shot:
         few_shot_generate(args)
     else:
diff --git a/llms/mistral/convert.py b/llms/mistral/convert.py
index cd3601e5..4c5c7ef1 100644
--- a/llms/mistral/convert.py
+++ b/llms/mistral/convert.py
@@ -42,7 +42,7 @@
         "--torch-path",
         type=str,
         default="mistral-7B-v0.1/",
-        help="The path to the PyTorch Mistral model.",
+        help="The path to the PyTorch model.",
     )
     parser.add_argument(
         "--mlx-path",
@@ -53,7 +53,7 @@
     parser.add_argument(
         "-q",
         "--quantize",
-        help="Generate a 4-bit quantized model.",
+        help="Generate a quantized model.",
         action="store_true",
     )
     parser.add_argument(
@@ -84,9 +84,6 @@
         weights, config = quantize(weights, config, args)
 
     # Save weights
-    import pdb
-
-    pdb.set_trace()
     np.savez(str(mlx_path / "weights.npz"), **weights)
 
     # Copy tokenizer
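
Note (not part of the patch): after this change, `load_model` in `llms/llama/llama.py` returns the tokenizer together with the model, so callers no longer pass a separate tokenizer path. Below is a minimal usage sketch, assuming the patch is applied and the converter has already written `weights.npz`, `config.json`, and `tokenizer.model` into the default `mlx_model` directory; the directory name and the `hello` prompt are illustrative only.

```
# Usage sketch only (not included in the patch). Assumes
# `python convert.py --torch-path <path_to_torch_model>` has already
# populated ./mlx_model, the converter's new default --mlx-path.
from pathlib import Path

from llama import load_model  # llms/llama/llama.py

# load_model now returns (model, tokenizer); the tokenizer is read from
# tokenizer.model stored next to the converted weights.
model, tokenizer = load_model(Path("mlx_model"))

# The second return value is a SentencePieceProcessor.
tokens = tokenizer.encode("hello")
print(tokens)
print(tokenizer.decode(tokens))
```

The same directory layout is what lets `llama.py --model-path` default to `mlx_model` and drop the old positional tokenizer argument.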