Apply reviews

bofenghuang 2023-12-29 11:19:05 +01:00
parent 39600eb383
commit 81183c3091
3 changed files with 16 additions and 29 deletions

convert.py

@@ -12,16 +12,16 @@ import numpy as np
 from mlx.utils import tree_flatten, tree_map, tree_unflatten
 from whisper.load_models import load_torch_model, torch_to_mlx
-from whisper.torch_whisper import ModelDimensions
-from whisper.whisper import Whisper
+from whisper.whisper import ModelDimensions, Whisper

 MODEL_DTYPES = {"float16", "float32"}


-def quantize(weights, config, dtype, args):
+def quantize(weights, config, args):
     quantized_config = copy.deepcopy(config)

     # Load the model:
-    model = Whisper(ModelDimensions(**config), dtype)
+    model = Whisper(ModelDimensions(**config))
     weights = tree_map(mx.array, weights)
     model.update(tree_unflatten(list(weights.items())))
@@ -39,7 +39,7 @@ def quantize(weights, config, dtype, args):
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")
+    parser = argparse.ArgumentParser(description="Convert Whisper weights to MLX.")
     parser.add_argument(
         "--torch-name-or-path",
         type=str,
@@ -88,7 +88,7 @@ if __name__ == "__main__":
     if args.quantize:
         print("[INFO] Quantizing")
-        weights, config = quantize(weights, config, dtype, args)
+        weights, config = quantize(weights, config, args)

     mlx_path = Path(args.mlx_path)
     mlx_path.mkdir(parents=True, exist_ok=True)
@@ -98,6 +98,6 @@ if __name__ == "__main__":
     np.savez(str(mlx_path / "weights.npz"), **weights)

     # Save config.json with model_type
-    with open(mlx_path / "config.json", "w") as f:
+    with open(str(mlx_path / "config.json"), "w") as f:
         config["model_type"] = "whisper"
         json.dump(config, f, indent=4)
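With these changes, convert.py builds the model without a dtype argument and the quantization settings travel through config.json. Below is a minimal sketch of the post-change quantize() flow, using only calls visible in this diff; the group_size and bits parameter names are assumptions standing in for whatever the script's args actually carries.

import copy

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_map, tree_unflatten
from whisper.whisper import ModelDimensions, Whisper


def quantize_sketch(weights, config, group_size=64, bits=4):
    quantized_config = copy.deepcopy(config)

    # Rebuild the model from the config alone; dtype no longer flows in.
    model = Whisper(ModelDimensions(**config))
    weights = tree_map(mx.array, weights)
    model.update(tree_unflatten(list(weights.items())))

    # Quantize the linear layers in place and record the settings so that
    # load_model() can re-apply them from config.json.
    nn.QuantizedLinear.quantize_module(model, group_size, bits)
    quantized_config["quantization"] = {"group_size": group_size, "bits": bits}

    return dict(tree_flatten(model.parameters())), quantized_config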

whisper/__init__.py

@@ -1,4 +1,4 @@
 # Copyright © 2023 Apple Inc.

-from . import audio, decoding, load_models, torch_whisper, whisper
+from . import audio, decoding, load_models
 from .transcribe import transcribe
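Dropping torch_whisper and whisper from the package-level imports keeps the PyTorch conversion code out of the default import path; the public surface is unchanged. A hypothetical session, assuming a converted model on the default path and an audio.mp3 on disk:

import whisper

# transcribe() is still re-exported at the package top level.
result = whisper.transcribe("audio.mp3")
print(result["text"])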

whisper/load_models.py

@@ -201,42 +201,29 @@ def load_model(
     if name_or_path in _MODELS:
         print(f"[INFO] Loading and converting {name_or_path} model")
         return torch_to_mlx(load_torch_model(name_or_path, download_root), dtype)
-    elif not (glob.glob(f"{name_or_path}/weights*.npz") and glob.glob(f"{name_or_path}/config.json")):
+    elif not (glob.glob(f"{name_or_path}/weights.npz") and glob.glob(f"{name_or_path}/config.json")):
         raise ValueError(
-            f"{name_or_path} not found in {available_models()}. Ensure that weights*.npz and config.json files are"
+            f"{name_or_path} not found in {available_models()}. Ensure that weights.npz and config.json files are"
             " present in the specified path"
         )

     model_path = Path(name_or_path)

-    unsharded_weights_path = model_path / "weights.npz"
-    if unsharded_weights_path.is_file():
-        print(f"[INFO] Loading model from {unsharded_weights_path}")
-        weights = mx.load(str(unsharded_weights_path))
-    else:
-        sharded_weights_glob = str(model_path / "weights.*.npz")
-        weight_files = glob.glob(sharded_weights_glob)
-        print(f"[INFO] Loading model from {sharded_weights_glob}")
-
-        if len(weight_files) == 0:
-            raise FileNotFoundError("No weights found in {}".format(model_path))
-
-        weights = {}
-        for wf in weight_files:
-            weights.update(mx.load(wf).items())
-
-    with open(model_path / "config.json", "r") as f:
+    with open(str(model_path / "config.json"), "r") as f:
         config = json.loads(f.read())
         config.pop("model_type", None)
         quantization = config.pop("quantization", None)

-    model_args = torch_whisper.ModelDimensions(**config)
+    model_args = whisper.ModelDimensions(**config)

+    weights = mx.load(str(model_path / "weights.npz"))
+    weights = tree_unflatten(list(weights.items()))

     model = whisper.Whisper(model_args, dtype)

     if quantization is not None:
         nn.QuantizedLinear.quantize_module(model, **quantization)

-    weights = tree_unflatten(list(weights.items()))
     model.update(weights)
     mx.eval(model.parameters())
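After this hunk, load_model() supports exactly one layout for local models: a single weights.npz beside config.json, with shard discovery removed. A minimal usage sketch; the directory name is hypothetical, and dtype is assumed to remain a keyword parameter as the truncated signature above suggests.

import mlx.core as mx

from whisper.load_models import load_model

# Either a name from _MODELS (converted from PyTorch on the fly) or a local
# directory containing weights.npz and config.json; sharded weights.*.npz
# files are no longer discovered.
model = load_model("mlx_models/tiny", dtype=mx.float16)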