From aed14618cab52a181224eb4f57e6a794e500e6cc Mon Sep 17 00:00:00 2001
From: Vaibhav Srivastav
Date: Wed, 20 Dec 2023 23:17:23 +0530
Subject: [PATCH] Add config.json to Mixtral. (#158)

* Add config.json to Mixtral.

* Update mixtral/mixtral.py

Co-authored-by: Pedro Cuenca

---------

Co-authored-by: Pedro Cuenca
---
 mixtral/convert.py | 3 +++
 mixtral/mixtral.py | 3 ++-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/mixtral/convert.py b/mixtral/convert.py
index d6ba8030..e89433a5 100644
--- a/mixtral/convert.py
+++ b/mixtral/convert.py
@@ -43,6 +43,9 @@ if __name__ == "__main__":
 
     with open("params.json") as fid:
         args = json.load(fid)
+    args["model_type"] = "mixtral"
+    with open(model_path / "config.json", "w") as f:
+        json.dump(args, f, indent=4)
 
     torch_files = glob.glob(str(model_path / "consolidated.*.pt"))
     torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2]))
diff --git a/mixtral/mixtral.py b/mixtral/mixtral.py
index 16a9eec8..2e82d305 100644
--- a/mixtral/mixtral.py
+++ b/mixtral/mixtral.py
@@ -247,8 +247,9 @@ class Tokenizer:
 def load_model(folder: str, dtype=mx.float16):
     model_path = Path(folder)
     tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
-    with open("params.json", "r") as f:
+    with open(model_path / "config.json", "r") as f:
         config = json.loads(f.read())
+    config.pop("model_type", None)
     model_args = ModelArgs(**config)
     weight_files = glob.glob(str(model_path / "weights.*.npz"))
     weights = {}