mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Add config.json to Mixtral. (#158)
* Add config.json to Mixtral. * Update mixtral/mixtral.py Co-authored-by: Pedro Cuenca <pedro@huggingface.co> --------- Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
parent
730c50d00a
commit
aed14618ca
@ -43,6 +43,9 @@ if __name__ == "__main__":
|
||||
|
||||
with open("params.json") as fid:
|
||||
args = json.load(fid)
|
||||
args["model_type"] = "mixtral"
|
||||
with open(model_path / "config.json", "w") as f:
|
||||
json.dump(args, f, indent=4)
|
||||
|
||||
torch_files = glob.glob(str(model_path / "consolidated.*.pt"))
|
||||
torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2]))
|
||||
|
@ -247,8 +247,9 @@ class Tokenizer:
|
||||
def load_model(folder: str, dtype=mx.float16):
|
||||
model_path = Path(folder)
|
||||
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
|
||||
with open("params.json", "r") as f:
|
||||
with open(model_path / "config.json", "r") as f:
|
||||
config = json.loads(f.read())
|
||||
config.pop("model_type", None)
|
||||
model_args = ModelArgs(**config)
|
||||
weight_files = glob.glob(str(model_path / "weights.*.npz"))
|
||||
weights = {}
|
||||
|
Loading…
Reference in New Issue
Block a user