From 730c50d00ab44d3851b3d6cc14aca8aa43382319 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 20 Dec 2023 17:39:37 +0100 Subject: [PATCH] Use config.json, add model_type (#157) * Use config.json, add model_type * Update convert to generate config.json --- mistral/convert.py | 8 ++++++++ mistral/mistral.py | 5 +++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/mistral/convert.py b/mistral/convert.py index 0efaf489..f7bcfd86 100644 --- a/mistral/convert.py +++ b/mistral/convert.py @@ -1,6 +1,7 @@ # Copyright © 2023 Apple Inc. import argparse +import json import numpy as np from pathlib import Path import torch @@ -22,3 +23,10 @@ if __name__ == "__main__": str(model_path / "weights.npz"), **{k: v.to(torch.float16).numpy() for k, v in state.items()} ) + + # Save config.json with model_type + with open(model_path / "params.json", "r") as f: + config = json.loads(f.read()) + config["model_type"] = "mistral" + with open(model_path / "config.json", "w") as f: + json.dump(config, f, indent=4) diff --git a/mistral/mistral.py b/mistral/mistral.py index 0c3976c1..11dcface 100644 --- a/mistral/mistral.py +++ b/mistral/mistral.py @@ -192,9 +192,10 @@ class Tokenizer: def load_model(folder: str, dtype=mx.float16): model_path = Path(folder) tokenizer = Tokenizer(str(model_path / "tokenizer.model")) - with open(model_path / "params.json", "r") as f: + with open(model_path / "config.json", "r") as f: config = json.loads(f.read()) - config.pop("sliding_window") + config.pop("sliding_window", None) + config.pop("model_type", None) model_args = ModelArgs(**config) weights = mx.load(str(model_path / "weights.npz")) weights = tree_unflatten(list(weights.items()))