mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Quantize example (#162)
* testing quantization
* conversion + quantization working
* one config processor
* quantization in mistral / nits in llama
* args for quantization
* llama / mistral conversion in good shape
* phi2 quantized
* mixtral
* qwen conversion
This commit is contained in:
@@ -244,20 +244,27 @@ class Tokenizer:
|
||||
return out
|
||||
|
||||
|
||||
def load_model(folder: str):
    """Load a (possibly quantized) Mixtral model and its tokenizer from disk.

    Args:
        folder: Path to a directory containing ``tokenizer.model``,
            ``config.json``, and one or more ``weights.*.npz`` shards.

    Returns:
        A ``(model, tokenizer)`` tuple with the weights loaded into the model.
    """
    model_path = Path(folder)
    tokenizer = Tokenizer(str(model_path / "tokenizer.model"))

    # Keep the `with` block minimal: only the file read needs the handle open.
    with open(model_path / "config.json", "r") as f:
        config = json.load(f)
    config.pop("model_type", None)
    # Presence of a "quantization" section in the config signals that the
    # saved weights are quantized and the model must be converted to match.
    quantization = config.pop("quantization", None)
    model_args = ModelArgs(**config)

    weight_files = glob.glob(str(model_path / "weights.*.npz"))
    weights = {}
    for wf in weight_files:
        weights.update(mx.load(wf).items())
    weights = tree_unflatten(list(weights.items()))

    model = Mixtral(model_args)
    if quantization is not None:
        # TODO: Quantize gate matrices when < 32 tiles supported
        # The gate projections have 8 output rows (one per expert), below the
        # minimum tile size, so they are excluded from quantization.
        quantization["linear_class_predicate"] = (
            lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] != 8
        )
        nn.QuantizedLinear.quantize_module(model, **quantization)

    model.update(weights)
    return model, tokenizer
|
||||
|
||||
@@ -284,7 +291,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
type=str,
|
||||
default="Mixtral-8x7B-v0.1",
|
||||
default="mlx_model",
|
||||
help="The path to the model weights, tokenizer, and config",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
Reference in New Issue
Block a user