From 81183c309151f24bc11593a8227a506132e0c2d8 Mon Sep 17 00:00:00 2001
From: bofenghuang
Date: Fri, 29 Dec 2023 11:19:05 +0100
Subject: [PATCH] Apply reviews

---
 whisper/convert.py             | 14 +++++++-------
 whisper/whisper/__init__.py    |  2 +-
 whisper/whisper/load_models.py | 29 ++++++++---------------------
 3 files changed, 16 insertions(+), 29 deletions(-)

diff --git a/whisper/convert.py b/whisper/convert.py
index 0d4e76d0..d0897b50 100644
--- a/whisper/convert.py
+++ b/whisper/convert.py
@@ -12,16 +12,16 @@ import numpy as np
 from mlx.utils import tree_flatten, tree_map, tree_unflatten
 
 from whisper.load_models import load_torch_model, torch_to_mlx
-from whisper.torch_whisper import ModelDimensions
-from whisper.whisper import Whisper
+from whisper.whisper import ModelDimensions, Whisper
 
 MODEL_DTYPES = {"float16", "float32"}
 
-def quantize(weights, config, dtype, args):
+
+def quantize(weights, config, args):
     quantized_config = copy.deepcopy(config)
 
     # Load the model:
-    model = Whisper(ModelDimensions(**config), dtype)
+    model = Whisper(ModelDimensions(**config))
     weights = tree_map(mx.array, weights)
     model.update(tree_unflatten(list(weights.items())))
 
@@ -39,7 +39,7 @@ def quantize(weights, config, dtype, args):
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")
+    parser = argparse.ArgumentParser(description="Convert Whisper weights to MLX.")
     parser.add_argument(
         "--torch-name-or-path",
         type=str,
@@ -88,7 +88,7 @@ if __name__ == "__main__":
 
     if args.quantize:
         print("[INFO] Quantizing")
-        weights, config = quantize(weights, config, dtype, args)
+        weights, config = quantize(weights, config, args)
 
     mlx_path = Path(args.mlx_path)
     mlx_path.mkdir(parents=True, exist_ok=True)
@@ -98,6 +98,6 @@ if __name__ == "__main__":
     np.savez(str(mlx_path / "weights.npz"), **weights)
 
     # Save config.json with model_type
-    with open(mlx_path / "config.json", "w") as f:
+    with open(str(mlx_path / "config.json"), "w") as f:
         config["model_type"] = "whisper"
         json.dump(config, f, indent=4)
diff --git a/whisper/whisper/__init__.py b/whisper/whisper/__init__.py
index f5b96966..e234711c 100644
--- a/whisper/whisper/__init__.py
+++ b/whisper/whisper/__init__.py
@@ -1,4 +1,4 @@
 # Copyright © 2023 Apple Inc.
 
-from . import audio, decoding, load_models, torch_whisper, whisper
+from . import audio, decoding, load_models
 from .transcribe import transcribe
diff --git a/whisper/whisper/load_models.py b/whisper/whisper/load_models.py
index 283754cb..05adbe49 100644
--- a/whisper/whisper/load_models.py
+++ b/whisper/whisper/load_models.py
@@ -201,42 +201,29 @@ def load_model(
     if name_or_path in _MODELS:
         print(f"[INFO] Loading and converting {name_or_path} model")
         return torch_to_mlx(load_torch_model(name_or_path, download_root), dtype)
-    elif not (glob.glob(f"{name_or_path}/weights*.npz") and glob.glob(f"{name_or_path}/config.json")):
+    elif not (glob.glob(f"{name_or_path}/weights.npz") and glob.glob(f"{name_or_path}/config.json")):
         raise ValueError(
-            f"{name_or_path} not found in {available_models()}. Ensure that weights*.npz and config.json files are"
+            f"{name_or_path} not found in {available_models()}. Ensure that weights.npz and config.json files are"
             " present in the specified path"
         )
 
     model_path = Path(name_or_path)
 
-    unsharded_weights_path = model_path / "weights.npz"
-    if unsharded_weights_path.is_file():
-        print(f"[INFO] Loading model from {unsharded_weights_path}")
-        weights = mx.load(str(unsharded_weights_path))
-    else:
-        sharded_weights_glob = str(model_path / "weights.*.npz")
-        weight_files = glob.glob(sharded_weights_glob)
-        print(f"[INFO] Loading model from {sharded_weights_glob}")
-
-        if len(weight_files) == 0:
-            raise FileNotFoundError("No weights found in {}".format(model_path))
-
-        weights = {}
-        for wf in weight_files:
-            weights.update(mx.load(wf).items())
-
-    with open(model_path / "config.json", "r") as f:
+    with open(str(model_path / "config.json"), "r") as f:
         config = json.loads(f.read())
     config.pop("model_type", None)
     quantization = config.pop("quantization", None)
 
-    model_args = torch_whisper.ModelDimensions(**config)
+    model_args = whisper.ModelDimensions(**config)
+
+    weights = mx.load(str(model_path / "weights.npz"))
+    weights = tree_unflatten(list(weights.items()))
 
     model = whisper.Whisper(model_args, dtype)
 
     if quantization is not None:
         nn.QuantizedLinear.quantize_module(model, **quantization)
 
-    weights = tree_unflatten(list(weights.items()))
     model.update(weights)
     mx.eval(model.parameters())
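
Reviewer note (not part of the commit): a minimal sketch of the round trip these changes imply, for anyone verifying the patch locally. The output path, the float16 choice, and the exact flag and keyword names are inferred from the argparse variables and the load_model body in the diff above, not quoted from the repository.

    # Convert (and optionally quantize) the PyTorch weights first, e.g.:
    #   python convert.py --torch-name-or-path tiny --mlx-path mlx_models/tiny --quantize
    import mlx.core as mx

    from whisper.load_models import load_model

    # load_model() now expects a single unsharded weights.npz next to
    # config.json and re-applies quantization when config.json carries a
    # "quantization" entry; it also calls mx.eval before returning, so the
    # parameters come back already materialized.
    model = load_model("mlx_models/tiny", dtype=mx.float16)  # hypothetical path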