mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 17:31:18 +08:00)

commit 6abdbe3be8 (parent 574ad7f6fe)
@@ -285,7 +285,7 @@ def load(gguf_file: str, repo: str = None):
             and f"{p}.scales" in weights
         )
         nn.quantize(
-            qm,
+            model,
             **quantization,
             class_predicate=class_predicate,
         )
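The hunk above only changes which object is handed to nn.quantize; the class_predicate it forwards is the part that decides what gets quantized. A minimal, self-contained sketch of that predicate pattern (the ToyModel, its layer names, and the placeholder weight dict below are illustrative, not taken from the repo):

import mlx.core as mx
import mlx.nn as nn


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)  # assume the checkpoint stores this layer quantized
        self.head = nn.Linear(64, 8)   # assume this one stays in full precision


# Stand-in for the loaded checkpoint: only the key names matter to the
# predicate, so the array values here are placeholders.
weights = {
    "proj.weight": mx.zeros(1),
    "proj.scales": mx.zeros(1),
    "proj.biases": mx.zeros(1),
    "head.weight": mx.zeros(1),
    "head.bias": mx.zeros(1),
}

model = ToyModel()
quantization = {"group_size": 64, "bits": 4}

# Quantize only Linear/Embedding modules whose scales actually exist in the
# checkpoint; everything else keeps its full-precision parameters.
class_predicate = (
    lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
    and f"{p}.scales" in weights
)
nn.quantize(
    model,
    **quantization,
    class_predicate=class_predicate,
)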
@@ -27,13 +27,17 @@ def load_model(
     model_args = whisper.ModelDimensions(**config)
 
     weights = mx.load(str(model_path / "weights.npz"))
-    weights = tree_unflatten(list(weights.items()))
 
     model = whisper.Whisper(model_args, dtype)
 
     if quantization is not None:
-        nn.quantize(model, **quantization)
+        class_predicate = (
+            lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
+            and f"{p}.scales" in weights
+        )
+        nn.quantize(model, **quantization, class_predicate=class_predicate)
 
+    weights = tree_unflatten(list(weights.items()))
     model.update(weights)
     mx.eval(model.parameters())
     return model
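The reordering in this hunk is the point: quantization now happens before the flat weights are unflattened and loaded, so a checkpoint that already contains .scales/.biases arrays lines up with the quantized parameters that nn.quantize creates. A hedged sketch of that flow with a generic module in place of whisper.Whisper (load_quantized and its arguments are illustrative names, not the repo's API):

from pathlib import Path
from typing import Optional

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten


def load_quantized(model: nn.Module, model_path: Path, quantization: Optional[dict]):
    # Load the flat checkpoint first; its keys drive the class_predicate below.
    weights = mx.load(str(model_path / "weights.npz"))

    if quantization is not None:
        # Swap eligible layers for quantized ones *before* loading parameters,
        # so the checkpoint's .scales / .biases match the quantized modules.
        class_predicate = (
            lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
            and f"{p}.scales" in weights
        )
        nn.quantize(model, **quantization, class_predicate=class_predicate)

    # Only now unflatten the checkpoint and load it into the (possibly
    # quantized) model, then force evaluation of the lazy arrays.
    model.update(tree_unflatten(list(weights.items())))
    mx.eval(model.parameters())
    return model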