diff --git a/mistral/mistral.py b/mistral/mistral.py index 0c3976c1..11dcface 100644 --- a/mistral/mistral.py +++ b/mistral/mistral.py @@ -192,9 +192,10 @@ class Tokenizer: def load_model(folder: str, dtype=mx.float16): model_path = Path(folder) tokenizer = Tokenizer(str(model_path / "tokenizer.model")) - with open(model_path / "params.json", "r") as f: + with open(model_path / "config.json", "r") as f: config = json.loads(f.read()) - config.pop("sliding_window") + config.pop("sliding_window", None) + config.pop("model_type", None) model_args = ModelArgs(**config) weights = mx.load(str(model_path / "weights.npz")) weights = tree_unflatten(list(weights.items()))