mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-05 16:34:34 +08:00
use official HF for mixtral
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
import argparse
|
||||
from dataclasses import dataclass
|
||||
import glob
|
||||
import json
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
@@ -222,10 +223,13 @@ class Tokenizer:
|
||||
def load_model(folder: str, dtype=mx.float16):
|
||||
model_path = Path(folder)
|
||||
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
|
||||
with open(model_path / "params.json", "r") as f:
|
||||
with open("params.json", "r") as f:
|
||||
config = json.loads(f.read())
|
||||
model_args = ModelArgs(**config)
|
||||
weights = mx.load(str(model_path / "weights.npz"))
|
||||
weight_files = glob.glob(str(model_path / "weights.*.npz"))
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(wf).items())
|
||||
weights = tree_unflatten(list(weights.items()))
|
||||
weights = tree_map(lambda p: p.astype(dtype), weights)
|
||||
model = Mixtral(model_args)
|
||||
@@ -255,7 +259,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default="mixtral-8x7b-32kseqlen",
|
||||
default="Mixtral-8x7B-v0.1",
|
||||
help="The path to the model weights, tokenizer, and config",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
Reference in New Issue
Block a user