args for quantization

Awni Hannun
2023-12-20 21:08:03 -08:00
parent 89db6ffdfe
commit aced530649
3 changed files with 60 additions and 23 deletions

README.md

@@ -23,16 +23,17 @@ tar -xf mistral-7B-v0.1.tar
 Then, convert the weights with:
 ```
-python convert.py
+python convert.py --torch-path <path_to_torch>
 ```
-To generate a 4-bit quantized model, use:
+To generate a 4-bit quantized model, use ``-q``. For a full list of options:
 ```
-python convert.py -q
+python convert.py --help
 ```
-The conversion script will save the converted weights in the same location.
+By default, the conversion script will make the directory `mlx_model` and save
+the converted `weights.npz`, `tokenizer.model`, and `config.json` there.
 > [!TIP]
 > Alternatively, you can also download a few converted checkpoints from the
@@ -46,7 +47,7 @@ Once you've converted the weights to MLX format, you can generate text with
 the Mistral model:
 ```
-python mistral.py --prompt "It is a truth universally acknowledged," --temp 0
+python mistral.py --prompt "It is a truth universally acknowledged,"
 ```
 Run `python mistral.py --help` for more details.

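For context, a minimal sketch (not part of this commit) of inspecting what the new conversion flow writes; it assumes the default `mlx_model` output directory from the `convert.py` changes below, and the `quantization` entry only appears when `-q` was passed:

```python
# Minimal sketch (not from the diff): inspect the converted output directory.
# Assumes the default --mlx-path of "mlx_model" introduced in convert.py below.
import json
from pathlib import Path

mlx_path = Path("mlx_model")
print(sorted(p.name for p in mlx_path.iterdir()))
# Expected: ['config.json', 'tokenizer.model', 'weights.npz']

with open(mlx_path / "config.json", "r") as f:
    config = json.load(f)

print(config.get("model_type"))    # "mistral"
# Present only when converted with -q, e.g. {"group_size": 64, "bits": 4}
print(config.get("quantization"))
```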
convert.py

@@ -1,7 +1,9 @@
 # Copyright © 2023 Apple Inc.
 import argparse
+import copy
 import json
+import shutil
 from pathlib import Path
 import mlx.core as mx
@@ -12,7 +14,7 @@ from mistral import Mistral, ModelArgs
 from mlx.utils import tree_flatten, tree_map, tree_unflatten
-def quantize(weights, config):
+def quantize(weights, config, args):
     quantized_config = copy.deepcopy(config)
     # Load the model:
@@ -22,10 +24,13 @@ def quantize(weights, config):
     model.update(tree_unflatten(list(weights.items())))
     # Quantize the model:
-    nn.QuantizedLinear.quantize_module(model)
+    nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
     # Update the config:
-    quantized_config["quantization"] = {"group_size": 64, "bits": 4}
+    quantized_config["quantization"] = {
+        "group_size": args.q_group_size,
+        "bits": args.q_bits,
+    }
     quantized_weights = dict(tree_flatten(model.parameters()))
     return quantized_weights, quantized_config
@@ -34,10 +39,16 @@ def quantize(weights, config):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")
     parser.add_argument(
-        "--model-path",
+        "--torch-path",
         type=str,
         default="mistral-7B-v0.1/",
-        help="The path to the Mistral model. The MLX weights will also be saved there.",
+        help="The path to the PyTorch Mistral model.",
+    )
+    parser.add_argument(
+        "--mlx-path",
+        type=str,
+        default="mlx_model",
+        help="The path to save the MLX model.",
     )
     parser.add_argument(
         "-q",
@@ -45,20 +56,46 @@ if __name__ == "__main__":
help="Generate a 4-bit quantized model.", help="Generate a 4-bit quantized model.",
action="store_true", action="store_true",
) )
parser.add_argument(
"--q_group_size",
help="Group size for quantization.",
type=int,
default=64,
)
parser.add_argument(
"--q_bits",
help="Bits per weight for quantization.",
type=int,
default=4,
)
args = parser.parse_args() args = parser.parse_args()
model_path = Path(args.model_path) torch_path = Path(args.torch_path)
state = torch.load(str(model_path / "consolidated.00.pth")) state = torch.load(str(torch_path / "consolidated.00.pth"))
mlx_path = Path(args.mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
weights = {k: v.to(torch.float16).numpy() for k, v in state.items()} weights = {k: v.to(torch.float16).numpy() for k, v in state.items()}
with open(torch_path / "params.json", "r") as f:
config = json.loads(f.read())
if args.quantize: if args.quantize:
print("[INFO] Quantizing") print("[INFO] Quantizing")
weights, params = quantize(weights, params) weights, config = quantize(weights, config, args)
np.savez(str(model_path / "weights.npz"), **weights) # Save weights
import pdb
pdb.set_trace()
np.savez(str(mlx_path / "weights.npz"), **weights)
# Copy tokenizer
shutil.copyfile(
str(torch_path / "tokenizer.model"),
str(mlx_path / "tokenizer.model"),
)
# Save config.json with model_type # Save config.json with model_type
with open(model_path / "params.json", "r") as f: with open(mlx_path / "config.json", "w") as f:
config = json.loads(f.read())
config["model_type"] = "mistral" config["model_type"] = "mistral"
with open(model_path / "config.json", "w") as f:
json.dump(config, f, indent=4) json.dump(config, f, indent=4)

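Bridging to the `mistral.py` changes below, here is a minimal sketch (not part of the diff) of how the `quantization` block recorded by `convert.py` is re-applied when the converted model is loaded. It assumes `Mistral` and `ModelArgs` from `mistral.py` and the default `mlx_model` directory; depending on `params.json`, additional unused config keys may also need to be popped before constructing `ModelArgs`.

```python
# Minimal sketch, not part of the diff: load a model converted by convert.py.
import json
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten

from mistral import Mistral, ModelArgs  # assumed importable from mistral.py

model_path = Path("mlx_model")  # default --mlx-path from convert.py
with open(model_path / "config.json", "r") as f:
    config = json.load(f)

config.pop("model_type", None)
quantization = config.pop("quantization", None)  # None for fp16 conversions

model = Mistral(ModelArgs(**config))
if quantization is not None:
    # Swap Linear layers for QuantizedLinear using the recorded
    # group_size/bits before loading the quantized weights.
    nn.QuantizedLinear.quantize_module(model, **quantization)

weights = mx.load(str(model_path / "weights.npz"))
model.update(tree_unflatten(list(weights.items())))
```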
mistral.py

@@ -8,7 +8,7 @@ from typing import List, Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
-from mlx.utils import tree_map, tree_unflatten
+from mlx.utils import tree_unflatten
 from sentencepiece import SentencePieceProcessor
@@ -189,7 +189,7 @@ class Tokenizer:
         return out
-def load_model(folder: str, dtype=mx.float16):
+def load_model(folder: str):
     model_path = Path(folder)
     tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
     with open(model_path / "config.json", "r") as f:
@@ -200,7 +200,6 @@ def load_model(folder: str, dtype=mx.float16):
         model_args = ModelArgs(**config)
     weights = mx.load(str(model_path / "weights.npz"))
     weights = tree_unflatten(list(weights.items()))
-    weights = tree_map(lambda p: p.astype(dtype), weights)
     model = Mistral(model_args)
     if quantization is not None:
         nn.QuantizedLinear.quantize_module(model, **quantization)
@@ -230,7 +229,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--model-path",
         type=str,
-        default="mistral-7B-v0.1",
+        default="mlx_model",
         help="The path to the model weights and tokenizer",
     )
     parser.add_argument(
@@ -239,7 +238,7 @@ if __name__ == "__main__":
default="In the beginning the Universe was created.", default="In the beginning the Universe was created.",
) )
parser.add_argument( parser.add_argument(
"--max_tokens", "--max-tokens",
"-m", "-m",
type=int, type=int,
default=100, default=100,
@@ -249,7 +248,7 @@ if __name__ == "__main__":
"--temp", "--temp",
help="The sampling temperature.", help="The sampling temperature.",
type=float, type=float,
default=1.0, default=0.0,
) )
parser.add_argument( parser.add_argument(
"--tokens_per_eval", "--tokens_per_eval",