Add config.json to Mixtral.

This commit is contained in:
Vaibhav Srivastav
2023-12-20 22:32:23 +05:30
parent 730c50d00a
commit f881821895
2 changed files with 5 additions and 1 deletions

View File

@@ -43,6 +43,9 @@ if __name__ == "__main__":
with open("params.json") as fid:
args = json.load(fid)
args["model_type"] = "mixtral"
with open(model_path / "config.json", "w") as f:
json.dump(args, f, indent=4)
torch_files = glob.glob(str(model_path / "consolidated.*.pt"))
torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2]))

View File

@@ -247,8 +247,9 @@ class Tokenizer:
def load_model(folder: str, dtype=mx.float16):
model_path = Path(folder)
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
with open("params.json", "r") as f:
with open("config.json", "r") as f:
config = json.loads(f.read())
config.pop("model_type", None)
model_args = ModelArgs(**config)
weight_files = glob.glob(str(model_path / "weights.*.npz"))
weights = {}