Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-25 01:41:19 +08:00)
Add config.json to Mixtral. (#158)
* Add config.json to Mixtral.

* Update mixtral/mixtral.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

This commit is contained in:
parent 730c50d00a
commit aed14618ca
@@ -43,6 +43,9 @@ if __name__ == "__main__":
 
     with open("params.json") as fid:
         args = json.load(fid)
+    args["model_type"] = "mixtral"
+    with open(model_path / "config.json", "w") as f:
+        json.dump(args, f, indent=4)
 
     torch_files = glob.glob(str(model_path / "consolidated.*.pt"))
     torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2]))
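For readers skimming the hunk above: this is the conversion side of the change. The file name is not shown in this view, but the `if __name__ == "__main__":` context points at the convert script. Below is a minimal standalone sketch of what it now does; the `model_path` value and the working-directory location of `params.json` are assumptions, since in the real script they come from the surrounding argument handling.

import glob
import json
from pathlib import Path

# Assumed checkpoint location; the actual script derives this from a
# command-line argument rather than a hard-coded path.
model_path = Path("path/to/mixtral")

# Read the upstream params.json, tag it with a model type, and write it
# next to the converted weights as config.json, so later loading no longer
# depends on a params.json sitting in the current working directory.
with open("params.json") as fid:
    args = json.load(fid)
args["model_type"] = "mixtral"
with open(model_path / "config.json", "w") as f:
    json.dump(args, f, indent=4)

# The sharded PyTorch checkpoints are then gathered in index order
# (consolidated.00.pt, consolidated.01.pt, ...).
torch_files = sorted(
    glob.glob(str(model_path / "consolidated.*.pt")),
    key=lambda tf: int(tf.split(".")[-2]),
)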
@@ -247,8 +247,9 @@ class Tokenizer:
 def load_model(folder: str, dtype=mx.float16):
     model_path = Path(folder)
     tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
-    with open("params.json", "r") as f:
+    with open(model_path / "config.json", "r") as f:
         config = json.loads(f.read())
+        config.pop("model_type", None)
     model_args = ModelArgs(**config)
     weight_files = glob.glob(str(model_path / "weights.*.npz"))
     weights = {}
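On the loading side (mixtral/mixtral.py, per the commit message), load_model now reads the config.json written by the converter instead of a params.json in the working directory, and strips the model_type key before building ModelArgs. Here is a minimal sketch of that config handling using a hypothetical load_config helper; ModelArgs, Tokenizer, and the rest of load_model are assumed from the surrounding file and not reproduced here.

import json
from pathlib import Path


def load_config(model_path: Path) -> dict:
    # Hypothetical helper mirroring the change in load_model: read the
    # config.json produced by the converter and drop the "model_type"
    # tag, since the remaining keys are meant to map onto ModelArgs fields.
    with open(model_path / "config.json", "r") as f:
        config = json.loads(f.read())
    config.pop("model_type", None)
    return config


# Usage inside load_model would look roughly like:
#   config = load_config(Path(folder))
#   model_args = ModelArgs(**config)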