Fix conversion + inference errors.

This commit is contained in:
Vaibhav Srivastav
2023-12-23 01:19:58 +05:30
parent 7ae445f6c7
commit adbd188c57
2 changed files with 2 additions and 0 deletions

View File

@@ -19,6 +19,7 @@ def quantize(weights, config, args):
# Load the model: # Load the model:
config.pop("sliding_window", None) config.pop("sliding_window", None)
config.pop("rope_theta", None)
model = Mistral(ModelArgs(**config)) model = Mistral(ModelArgs(**config))
weights = tree_map(mx.array, weights) weights = tree_map(mx.array, weights)
model.update(tree_unflatten(list(weights.items()))) model.update(tree_unflatten(list(weights.items())))

View File

@@ -196,6 +196,7 @@ def load_model(folder: str):
with open(model_path / "config.json", "r") as f: with open(model_path / "config.json", "r") as f:
config = json.loads(f.read()) config = json.loads(f.read())
config.pop("sliding_window", None) config.pop("sliding_window", None)
config.pop("rope_theta", None)
config.pop("model_type", None) config.pop("model_type", None)
quantization = config.pop("quantization", None) quantization = config.pop("quantization", None)
model_args = ModelArgs(**config) model_args = ModelArgs(**config)