Add config.json to Mixtral. (#158)

* Add config.json to Mixtral.

* Update mixtral/mixtral.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Vaibhav Srivastav 2023-12-20 23:17:23 +05:30 committed by GitHub
parent 730c50d00a
commit aed14618ca
2 changed files with 5 additions and 1 deletion


@@ -43,6 +43,9 @@ if __name__ == "__main__":
 
     with open("params.json") as fid:
         args = json.load(fid)
+        args["model_type"] = "mixtral"
+        with open(model_path / "config.json", "w") as f:
+            json.dump(args, f, indent=4)
 
     torch_files = glob.glob(str(model_path / "consolidated.*.pt"))
     torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2]))
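
For reference, the conversion-side change above amounts to the standalone sketch below; the directory name is a placeholder and the comment about params.json contents is an assumption, not something this diff specifies.

import json
from pathlib import Path

# Placeholder directory; point this at wherever params.json and the
# consolidated.*.pt shards were downloaded.
model_path = Path("mixtral-8x7b-32kseqlen")

# params.json carries the architecture hyperparameters shipped with the
# original checkpoint (dim, n_layers, n_heads, MoE settings, ...).
with open(model_path / "params.json") as fid:
    args = json.load(fid)
    # Tag the checkpoint so loaders can identify it, then persist the
    # result as config.json next to the converted weights.
    args["model_type"] = "mixtral"
    with open(model_path / "config.json", "w") as f:
        json.dump(args, f, indent=4)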

mixtral/mixtral.py

@@ -247,8 +247,9 @@ class Tokenizer:
 def load_model(folder: str, dtype=mx.float16):
     model_path = Path(folder)
     tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
-    with open("params.json", "r") as f:
+    with open(model_path / "config.json", "r") as f:
         config = json.loads(f.read())
+        config.pop("model_type", None)
         model_args = ModelArgs(**config)
     weight_files = glob.glob(str(model_path / "weights.*.npz"))
     weights = {}
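
The added config.pop is what keeps the keyword unpacking below it working: the converter now writes a model_type tag that is not a field of ModelArgs, so ModelArgs(**config) would raise a TypeError on the unexpected key. Below is a minimal, self-contained sketch of the pattern; the two-field dataclass is a deliberately shrunken stand-in for the real ModelArgs in mixtral.py, and the values are illustrative.

from dataclasses import dataclass

# Shrunken stand-in for the real ModelArgs; illustrative only.
@dataclass
class ModelArgs:
    dim: int
    n_layers: int

# A shortened dict of the kind the converter writes to config.json.
config = {"dim": 4096, "n_layers": 32, "model_type": "mixtral"}

# Unpacking would fail on the unexpected "model_type" key, so drop the
# tag before constructing the dataclass.
config.pop("model_type", None)
model_args = ModelArgs(**config)
print(model_args)  # ModelArgs(dim=4096, n_layers=32)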