Fix conversion + inference errors.

Vaibhav Srivastav
2023-12-23 01:19:58 +05:30
parent 7ae445f6c7
commit adbd188c57
2 changed files with 2 additions and 0 deletions


@@ -19,6 +19,7 @@ def quantize(weights, config, args):
     # Load the model:
     config.pop("sliding_window", None)
+    config.pop("rope_theta", None)
     model = Mistral(ModelArgs(**config))
     weights = tree_map(mx.array, weights)
     model.update(tree_unflatten(list(weights.items())))


@@ -196,6 +196,7 @@ def load_model(folder: str):
     with open(model_path / "config.json", "r") as f:
         config = json.loads(f.read())
         config.pop("sliding_window", None)
+        config.pop("rope_theta", None)
         config.pop("model_type", None)
         quantization = config.pop("quantization", None)
         model_args = ModelArgs(**config)
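
For context: newer Mistral config.json files carry a rope_theta entry, while ModelArgs in this example is a plain dataclass with a fixed set of fields, so unpacking the raw config with ModelArgs(**config) raises a TypeError for the unknown keyword. Popping the key first, as both hunks do, avoids that. Below is a minimal sketch of the failure mode, using a trimmed-down, hypothetical ModelArgs with only two fields (the real dataclass has more, but none of them is rope_theta):

from dataclasses import dataclass


# Hypothetical, trimmed-down stand-in for the example's ModelArgs dataclass.
@dataclass
class ModelArgs:
    dim: int
    n_layers: int


config = {"dim": 4096, "n_layers": 32, "rope_theta": 10000.0}

# Unpacking the raw config fails, because the generated __init__ rejects
# unknown keyword arguments:
#   ModelArgs(**config)
#   TypeError: ... got an unexpected keyword argument 'rope_theta'

# Dropping the key first, as the commit does, lets construction succeed:
config.pop("rope_theta", None)
args = ModelArgs(**config)
print(args)  # ModelArgs(dim=4096, n_layers=32)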