mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
parent
574ad7f6fe
commit
6abdbe3be8
@ -285,7 +285,7 @@ def load(gguf_file: str, repo: str = None):
|
||||
and f"{p}.scales" in weights
|
||||
)
|
||||
nn.quantize(
|
||||
qm,
|
||||
model,
|
||||
**quantization,
|
||||
class_predicate=class_predicate,
|
||||
)
|
||||
|
@ -27,13 +27,17 @@ def load_model(
|
||||
model_args = whisper.ModelDimensions(**config)
|
||||
|
||||
weights = mx.load(str(model_path / "weights.npz"))
|
||||
weights = tree_unflatten(list(weights.items()))
|
||||
|
||||
model = whisper.Whisper(model_args, dtype)
|
||||
|
||||
if quantization is not None:
|
||||
nn.quantize(model, **quantization)
|
||||
class_predicate = (
|
||||
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
|
||||
and f"{p}.scales" in weights
|
||||
)
|
||||
nn.quantize(model, **quantization, class_predicate=class_predicate)
|
||||
|
||||
weights = tree_unflatten(list(weights.items()))
|
||||
model.update(weights)
|
||||
mx.eval(model.parameters())
|
||||
return model
|
||||
|
Loading…
Reference in New Issue
Block a user