mirror of https://github.com/ml-explore/mlx-examples.git
Apply reviews

commit 81183c3091
parent 39600eb383
@@ -12,16 +12,16 @@ import numpy as np
 from mlx.utils import tree_flatten, tree_map, tree_unflatten
 
-from whisper.torch_whisper import ModelDimensions
-from whisper.whisper import Whisper
+from whisper.load_models import load_torch_model, torch_to_mlx
+from whisper.whisper import ModelDimensions, Whisper
 
 MODEL_DTYPES = {"float16", "float32"}
 
 
-def quantize(weights, config, dtype, args):
+def quantize(weights, config, args):
     quantized_config = copy.deepcopy(config)
 
     # Load the model:
-    model = Whisper(ModelDimensions(**config), dtype)
+    model = Whisper(ModelDimensions(**config))
     weights = tree_map(mx.array, weights)
     model.update(tree_unflatten(list(weights.items())))
 
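The hunk above (presumably whisper's convert.py) shows only the head of quantize. For orientation, here is a minimal sketch of how the rest of the function plausibly proceeds, modeled on the other convert scripts in mlx-examples; the args.q_group_size and args.q_bits flag names are assumptions, while the group_size/bits keys follow from load_model re-applying them via nn.QuantizedLinear.quantize_module(model, **quantization) in the last hunk below:

    import copy

    import mlx.core as mx
    import mlx.nn as nn
    from mlx.utils import tree_flatten, tree_map, tree_unflatten
    from whisper.whisper import ModelDimensions, Whisper

    def quantize(weights, config, args):
        quantized_config = copy.deepcopy(config)

        # Load the model so the quantizer can walk its Linear layers:
        model = Whisper(ModelDimensions(**config))
        weights = tree_map(mx.array, weights)
        model.update(tree_unflatten(list(weights.items())))

        # Swap Linear layers for QuantizedLinear in place
        # (args.q_group_size / args.q_bits are assumed flag names):
        nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)

        # Record the settings so load_model can re-apply them on load:
        quantized_config["quantization"] = {
            "group_size": args.q_group_size,
            "bits": args.q_bits,
        }
        quantized_weights = dict(tree_flatten(model.parameters()))

        return quantized_weights, quantized_config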
@@ -39,7 +39,7 @@ def quantize(weights, config, dtype, args):
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")
+    parser = argparse.ArgumentParser(description="Convert Whisper weights to MLX.")
     parser.add_argument(
         "--torch-name-or-path",
         type=str,
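The description fix corrects a copy-paste left over from the Mistral converter. Only --torch-name-or-path is visible in this hunk; judging by args.mlx_path and args.quantize used later in the diff, the parser plausibly continues along these lines (defaults and help strings are assumptions):

    import argparse

    parser = argparse.ArgumentParser(description="Convert Whisper weights to MLX.")
    parser.add_argument(
        "--torch-name-or-path",
        type=str,
        default="tiny",
        help="The torch checkpoint name or path to convert.",
    )
    parser.add_argument(
        "--mlx-path",
        type=str,
        default="mlx_models",
        help="Directory to write weights.npz and config.json to.",
    )
    parser.add_argument(
        "--quantize",
        action="store_true",
        help="Quantize the model during conversion.",
    )
    args = parser.parse_args()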
@@ -88,7 +88,7 @@ if __name__ == "__main__":
 
     if args.quantize:
         print("[INFO] Quantizing")
-        weights, config = quantize(weights, config, dtype, args)
+        weights, config = quantize(weights, config, args)
 
     mlx_path = Path(args.mlx_path)
     mlx_path.mkdir(parents=True, exist_ok=True)
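For context, the call site above presumably sits in a flow like the sketch below. It is pieced together from names visible in this diff (load_torch_model, torch_to_mlx, tree_flatten, args.quantize); the dtype handling and the dataclasses.asdict call on model.dims are assumptions:

    from dataclasses import asdict

    import mlx.core as mx
    from mlx.utils import tree_flatten
    from whisper.load_models import load_torch_model, torch_to_mlx

    # Convert the torch checkpoint to an MLX Whisper model
    # (download_root assumed to default; dtype selection from
    # MODEL_DTYPES elided, float16 assumed here):
    model = torch_to_mlx(load_torch_model(args.torch_name_or_path), mx.float16)

    # Flatten the parameters into the {name: array} dict that gets saved,
    # and derive the config from the model dimensions (assumed dataclass):
    weights = dict(tree_flatten(model.parameters()))
    config = asdict(model.dims)

    if args.quantize:
        print("[INFO] Quantizing")
        weights, config = quantize(weights, config, args)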
@@ -98,6 +98,6 @@ if __name__ == "__main__":
     np.savez(str(mlx_path / "weights.npz"), **weights)
 
     # Save config.json with model_type
-    with open(mlx_path / "config.json", "w") as f:
+    with open(str(mlx_path / "config.json"), "w") as f:
         config["model_type"] = "whisper"
         json.dump(config, f, indent=4)
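One nit: open() accepts pathlib.Path objects directly, so wrapping the path in str() is a consistency choice rather than a bug fix. The script's output contract is a flat {name: array} weights.npz plus a config.json carrying the model dimensions, model_type, and any quantization settings; a minimal sketch of reading them back (mlx_models is an illustrative path):

    import json
    from pathlib import Path

    import mlx.core as mx

    mlx_path = Path("mlx_models")  # illustrative output directory

    # weights.npz holds a flat {parameter-name: array} mapping:
    weights = mx.load(str(mlx_path / "weights.npz"))

    # config.json holds the ModelDimensions fields plus bookkeeping keys:
    with open(str(mlx_path / "config.json"), "r") as f:
        config = json.loads(f.read())
    config.pop("model_type", None)                   # "whisper"
    quantization = config.pop("quantization", None)  # {"group_size": ..., "bits": ...} or None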
@@ -1,4 +1,4 @@
 # Copyright © 2023 Apple Inc.
 
-from . import audio, decoding, load_models, torch_whisper, whisper
+from . import audio, decoding, load_models
 from .transcribe import transcribe
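With torch_whisper and whisper no longer re-exported, the package surface narrows to the high-level helpers. Going by the usage shown in the repository README, transcription still looks roughly like:

    import whisper

    # "speech.mp3" is an illustrative audio path:
    result = whisper.transcribe("speech.mp3")
    print(result["text"])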
@@ -201,42 +201,29 @@ def load_model(
     if name_or_path in _MODELS:
         print(f"[INFO] Loading and converting {name_or_path} model")
         return torch_to_mlx(load_torch_model(name_or_path, download_root), dtype)
-    elif not (glob.glob(f"{name_or_path}/weights*.npz") and glob.glob(f"{name_or_path}/config.json")):
+    elif not (glob.glob(f"{name_or_path}/weights.npz") and glob.glob(f"{name_or_path}/config.json")):
         raise ValueError(
-            f"{name_or_path} not found in {available_models()}. Ensure that weights*.npz and config.json files are"
+            f"{name_or_path} not found in {available_models()}. Ensure that weights.npz and config.json files are"
             " present in the specified path"
         )
 
     model_path = Path(name_or_path)
 
-    unsharded_weights_path = model_path / "weights.npz"
-    if unsharded_weights_path.is_file():
-        print(f"[INFO] Loading model from {unsharded_weights_path}")
-        weights = mx.load(str(unsharded_weights_path))
-    else:
-        sharded_weights_glob = str(model_path / "weights.*.npz")
-        weight_files = glob.glob(sharded_weights_glob)
-        print(f"[INFO] Loading model from {sharded_weights_glob}")
-
-        if len(weight_files) == 0:
-            raise FileNotFoundError("No weights found in {}".format(model_path))
-
-        weights = {}
-        for wf in weight_files:
-            weights.update(mx.load(wf).items())
-
-    with open(model_path / "config.json", "r") as f:
+    with open(str(model_path / "config.json"), "r") as f:
         config = json.loads(f.read())
         config.pop("model_type", None)
         quantization = config.pop("quantization", None)
 
-    model_args = torch_whisper.ModelDimensions(**config)
+    model_args = whisper.ModelDimensions(**config)
+
+    weights = mx.load(str(model_path / "weights.npz"))
+    weights = tree_unflatten(list(weights.items()))
+
     model = whisper.Whisper(model_args, dtype)
 
     if quantization is not None:
         nn.QuantizedLinear.quantize_module(model, **quantization)
 
-    weights = tree_unflatten(list(weights.items()))
     model.update(weights)
 
     mx.eval(model.parameters())
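Finally, a hedged usage sketch of the updated loader (likely load_models.py). The hunk body shows load_model taking name_or_path, download_root, and dtype (defaults assumed), loading the flat weights.npz, and, for a quantized export, re-applying the quantization entry saved by convert.py with nn.QuantizedLinear.quantize_module before model.update installs the already-quantized weights:

    import mlx.core as mx

    from whisper.load_models import load_model

    # "mlx_models/tiny" is an illustrative directory written by convert.py;
    # the dtype keyword (assumed) applies to the layers left un-quantized.
    model = load_model("mlx_models/tiny", dtype=mx.float16)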