Use config.json, add model_type

This commit is contained in:
Pedro Cuenca 2023-12-20 17:05:27 +01:00
parent b6e62caf2e
commit 9318d99b9c

View File

@ -192,9 +192,10 @@ class Tokenizer:
def load_model(folder: str, dtype=mx.float16):
model_path = Path(folder)
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
with open(model_path / "params.json", "r") as f:
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
config.pop("sliding_window")
config.pop("sliding_window", None)
config.pop("model_type", None)
model_args = ModelArgs(**config)
weights = mx.load(str(model_path / "weights.npz"))
weights = tree_unflatten(list(weights.items()))