mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00
Whisper updates to allow HF models (#923)
* simplify conversion and update convert for HF models * use npz for compat * fixes * fixes * fix gguf * allow user supplied path
This commit is contained in:
parent
df744c98e6
commit
33905447f9
@ -32,6 +32,7 @@ Some more useful examples are listed below.
|
||||
|
||||
- Joint text and image embeddings with [CLIP](clip).
|
||||
- Text generation from image and text inputs with [LLaVA](llava).
|
||||
- Image segmentation with [Segment Anything (SAM)](segment_anything).
|
||||
|
||||
### Other Models
|
||||
|
||||
|
@ -59,7 +59,7 @@ class HfVocab:
|
||||
for token_id in range(self.vocab_size_base):
|
||||
if token_id in self.added_tokens_ids:
|
||||
continue
|
||||
token_text = reverse_vocab[token_id].encode("utf-8")
|
||||
token_text = reverse_vocab[token_id]
|
||||
yield token_text, self.get_token_score(token_id), self.get_token_type(
|
||||
token_id, token_text, self.special_ids
|
||||
)
|
||||
@ -67,7 +67,7 @@ class HfVocab:
|
||||
def get_token_type(
|
||||
self, token_id: int, token_text: bytes, special_ids: Set[int]
|
||||
) -> TokenType:
|
||||
if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text):
|
||||
if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text.encode("utf-8")):
|
||||
return TokenType.BYTE
|
||||
return TokenType.CONTROL if token_id in special_ids else TokenType.NORMAL
|
||||
|
||||
@ -84,7 +84,7 @@ class HfVocab:
|
||||
else:
|
||||
toktype = TokenType.USER_DEFINED
|
||||
score = -1000.0
|
||||
yield text.encode("utf-8"), score, toktype
|
||||
yield text, score, toktype
|
||||
|
||||
def has_newline_token(self):
|
||||
return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab
|
||||
|
@ -105,6 +105,88 @@ def available_models() -> List[str]:
|
||||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
def hf_to_pt(weights, config):
    """Rename a Hugging Face Whisper state dict and config to the ``n_*`` /
    ``blocks`` naming used by the rest of this converter.

    Parameters
    ----------
    weights : dict
        Mapping from Hugging Face parameter names to their values. It is
        NOT modified; a new dict is returned.
    config : dict
        Parsed Hugging Face ``config.json``. Must contain the keys read
        below (``num_mel_bins``, ``d_model``, ``encoder_layers``, ...); a
        missing key raises ``KeyError``.

    Returns
    -------
    tuple
        ``(weights, config)`` where ``weights`` has renamed keys (and no
        ``proj_out.weight`` entry) and ``config`` holds the ``n_*``
        dimension entries.
    """
    dims = {
        "n_mels": config["num_mel_bins"],
        "n_audio_ctx": config["max_source_positions"],
        "n_audio_state": config["d_model"],
        "n_audio_head": config["encoder_attention_heads"],
        "n_audio_layer": config["encoder_layers"],
        "n_vocab": config["vocab_size"],
        "n_text_ctx": config["max_target_positions"],
        "n_text_state": config["d_model"],
        "n_text_head": config["decoder_attention_heads"],
        "n_text_layer": config["decoder_layers"],
    }

    # Substring rewrites, applied in this exact order. Order matters:
    # ".self_attn" must become ".attn" before ".attn_layer_norm" is
    # rewritten, so that "self_attn_layer_norm" ends up as "attn_ln".
    renames = (
        ("model.", ""),
        (".layers", ".blocks"),
        (".self_attn", ".attn"),
        (".attn_layer_norm", ".attn_ln"),
        (".encoder_attn.", ".cross_attn."),
        (".encoder_attn_layer_norm", ".cross_attn_ln"),
        (".final_layer_norm", ".mlp_ln"),
        (".q_proj", ".query"),
        (".k_proj", ".key"),
        (".v_proj", ".value"),
        (".out_proj", ".out"),
        (".fc1", ".mlp1"),
        (".fc2", ".mlp2"),
        ("embed_positions.weight", "positional_embedding"),
        ("decoder.embed_tokens", "decoder.token_embedding"),
        ("encoder.layer_norm", "encoder.ln_post"),
        ("decoder.layer_norm", "decoder.ln"),
    )

    def remap(k):
        for old, new in renames:
            k = k.replace(old, new)
        return k

    # token embeddings are shared with output projection, so the separate
    # "proj_out.weight" entry is dropped. Filtering here (instead of
    # popping) leaves the caller's dict untouched.
    remapped = {
        remap(k): v for k, v in weights.items() if k != "proj_out.weight"
    }
    return remapped, dims
|
||||
|
||||
|
||||
def load_torch_weights_and_config(
    name_or_path: str,
    download_root: str = None,
):
    """Load Whisper weights and model dimensions from one of three sources.

    Parameters
    ----------
    name_or_path : str
        Resolved in this order:
        1. a key of ``_MODELS`` (see ``available_models()``), fetched with
           ``_download`` into ``download_root``;
        2. an existing local path, used as-is;
        3. anything else is treated as a Hugging Face repo id and fetched
           with ``huggingface_hub.snapshot_download``.
        A resolved path ending in ``.pt`` is read as a single checkpoint
        holding ``"model_state_dict"`` and ``"dims"``; any other path is
        read as a directory containing ``pytorch_model.bin`` and
        ``config.json``, which are converted with ``hf_to_pt``.
    download_root : str
        Directory for built-in model downloads; defaults to
        ``~/.cache/whisper``.

    Returns
    -------
    tuple
        ``(weights, config, alignment_heads)``. ``alignment_heads`` is
        ``None`` unless ``name_or_path`` was a built-in model name.

    Raises
    ------
    Whatever ``snapshot_download`` raises when the name is neither a
    built-in model, a local path, nor a reachable Hugging Face repo.
    """
    # Accept path-like objects too; the code below relies on str methods.
    name_or_path = os.fspath(name_or_path)

    if download_root is None:
        download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")

    # todo: accept alignment_heads of local Pytorch checkpoint
    alignment_heads = None
    if name_or_path in _MODELS:
        alignment_heads = _ALIGNMENT_HEADS[name_or_path]
        name_or_path = _download(_MODELS[name_or_path], download_root)
    elif not Path(name_or_path).exists():
        # Try downloading from HF
        from huggingface_hub import snapshot_download

        name_or_path = snapshot_download(
            repo_id=name_or_path,
            allow_patterns=["*.json", "pytorch_model.bin", "*.txt"],
        )
    # Otherwise name_or_path is an existing local path and is used
    # unchanged. (Previously this case fell into an `else` that raised
    # RuntimeError, which made every local checkpoint unloadable.)

    # NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from sources you trust.
    if name_or_path.endswith(".pt"):
        checkpoint = torch.load(name_or_path, map_location="cpu")
        weights, config = checkpoint["model_state_dict"], checkpoint["dims"]
    else:
        model_dir = Path(name_or_path)
        weights = torch.load(
            model_dir / "pytorch_model.bin",
            map_location="cpu",
        )
        with open(model_dir / "config.json", "r") as fp:
            config = json.load(fp)
        weights, config = hf_to_pt(weights, config)

    return weights, config, alignment_heads
|
||||
|
||||
|
||||
def load_torch_model(
|
||||
name_or_path: str,
|
||||
download_root: str = None,
|
||||
@ -115,7 +197,8 @@ def load_torch_model(
|
||||
Parameters
|
||||
----------
|
||||
name_or_path : str
|
||||
one of the official model names listed by `whisper.available_models()` or a local Pytorch checkpoint which is in the original OpenAI format
|
||||
one of the official model names listed by `whisper.available_models()` or
|
||||
a local Pytorch checkpoint which is in the original OpenAI format
|
||||
download_root: str
|
||||
path to download the model files; by default, it uses "~/.cache/whisper"
|
||||
|
||||
@ -128,22 +211,12 @@ def load_torch_model(
|
||||
if download_root is None:
|
||||
download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")
|
||||
|
||||
# todo: accept alignment_heads of local Pytorch checkpoint
|
||||
alignment_heads = None
|
||||
if name_or_path in _MODELS:
|
||||
alignment_heads = _ALIGNMENT_HEADS[name_or_path]
|
||||
name_or_path = _download(_MODELS[name_or_path], download_root)
|
||||
elif not Path(name_or_path).is_file():
|
||||
raise RuntimeError(
|
||||
f"Model {name_or_path} is neither found in {available_models()} nor as a local path"
|
||||
)
|
||||
|
||||
with open(name_or_path, "rb") as fp:
|
||||
checkpoint = torch.load(fp)
|
||||
|
||||
dims = torch_whisper.ModelDimensions(**checkpoint["dims"])
|
||||
weights, config, alignment_heads = load_torch_weights_and_config(
|
||||
name_or_path, download_root
|
||||
)
|
||||
dims = torch_whisper.ModelDimensions(**config)
|
||||
model = torch_whisper.Whisper(dims)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
model.load_state_dict(weights)
|
||||
|
||||
if alignment_heads is not None:
|
||||
model.set_alignment_heads(alignment_heads)
|
||||
@ -151,59 +224,26 @@ def load_torch_model(
|
||||
return model
|
||||
|
||||
|
||||
def convert(model, rules=None):
|
||||
params = {}
|
||||
if rules is not None and type(model) in rules:
|
||||
out = rules[type(model)](model, rules)
|
||||
return out
|
||||
if isinstance(model, torch.Tensor):
|
||||
return mx.array(model.detach().numpy())
|
||||
if isinstance(model, torch.nn.ModuleList):
|
||||
return [convert(n, rules) for n in model.children()]
|
||||
if isinstance(model, torch.nn.Conv1d):
|
||||
return {
|
||||
"weight": convert(model.weight).transpose(0, 2, 1),
|
||||
"bias": convert(model.bias),
|
||||
}
|
||||
for k, n in model.named_children():
|
||||
if k in rules:
|
||||
params.update(rules[k](n, rules))
|
||||
else:
|
||||
params[k] = convert(n, rules)
|
||||
for k, p in model.named_parameters(recurse=False):
|
||||
params[k] = convert(p)
|
||||
return params
|
||||
def convert(name_or_path: str, dtype: mx.Dtype = mx.float16):
|
||||
def remap(key, value):
|
||||
key = key.replace("mlp.0", "mlp1")
|
||||
key = key.replace("mlp.2", "mlp2")
|
||||
if "conv" in key and value.ndim == 3:
|
||||
value = value.swapaxes(1, 2)
|
||||
return key, mx.array(value.detach()).astype(dtype)
|
||||
|
||||
weights, config, alignment_heads = load_torch_weights_and_config(name_or_path)
|
||||
weights.pop("encoder.positional_embedding", None)
|
||||
weights = dict(remap(k, v) for k, v in weights.items())
|
||||
|
||||
def torch_to_mlx(
|
||||
torch_model: torch_whisper.Whisper,
|
||||
dtype: mx.Dtype = mx.float16,
|
||||
) -> Whisper:
|
||||
def convert_rblock(model, rules):
|
||||
children = dict(model.named_children())
|
||||
mlp = list(children.pop("mlp").children())
|
||||
params = {
|
||||
"mlp1": convert(mlp[0], rules),
|
||||
"mlp2": convert(mlp[-1], rules),
|
||||
}
|
||||
for k, n in children.items():
|
||||
params[k] = convert(n, rules)
|
||||
return params
|
||||
model_dims = ModelDimensions(**config)
|
||||
model = Whisper(model_dims, dtype)
|
||||
model.load_weights(list(weights.items()), strict=False)
|
||||
|
||||
rules = {
|
||||
torch_whisper.ResidualAttentionBlock: convert_rblock,
|
||||
}
|
||||
if alignment_heads is not None:
|
||||
model.set_alignment_heads(alignment_heads)
|
||||
|
||||
params = convert(torch_model, rules)
|
||||
|
||||
mlx_model = Whisper(torch_model.dims, dtype)
|
||||
params = tree_map(lambda p: p.astype(dtype), params)
|
||||
mlx_model.update(params)
|
||||
|
||||
if (alignment_heads := getattr(torch_model, "alignment_heads", None)) is not None:
|
||||
mlx_model.set_alignment_heads(alignment_heads.indices().T.numpy())
|
||||
|
||||
return mlx_model
|
||||
return model
|
||||
|
||||
|
||||
def upload_to_hub(path: str, name: str, torch_name_or_path: str):
|
||||
@ -292,13 +332,13 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q_group_size",
|
||||
"--q-group-size",
|
||||
help="Group size for quantization.",
|
||||
type=int,
|
||||
default=64,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q_bits",
|
||||
"--q-bits",
|
||||
help="Bits per weight for quantization.",
|
||||
type=int,
|
||||
default=4,
|
||||
@ -318,7 +358,7 @@ if __name__ == "__main__":
|
||||
dtype = getattr(mx, args.dtype)
|
||||
|
||||
print("[INFO] Loading")
|
||||
model = torch_to_mlx(load_torch_model(args.torch_name_or_path), dtype)
|
||||
model = convert(args.torch_name_or_path, dtype)
|
||||
config = asdict(model.dims)
|
||||
weights = dict(tree_flatten(model.parameters()))
|
||||
|
||||
|
@ -1,3 +1,3 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
__version__ = "0.1.0"
|
||||
__version__ = "0.2.0"
|
||||
|
@ -13,7 +13,7 @@ import mlx_whisper.decoding as decoding
|
||||
import mlx_whisper.load_models as load_models
|
||||
import numpy as np
|
||||
import torch
|
||||
from convert import load_torch_model, quantize, torch_to_mlx
|
||||
from convert import convert, load_torch_model, quantize
|
||||
from mlx.utils import tree_flatten
|
||||
|
||||
MODEL_NAME = "tiny"
|
||||
@ -41,12 +41,12 @@ def _save_model(save_dir, weights, config):
|
||||
def load_torch_and_mlx():
|
||||
torch_model = load_torch_model(MODEL_NAME)
|
||||
|
||||
fp32_model = torch_to_mlx(torch_model, dtype=mx.float32)
|
||||
fp32_model = convert(MODEL_NAME, dtype=mx.float32)
|
||||
config = asdict(fp32_model.dims)
|
||||
weights = dict(tree_flatten(fp32_model.parameters()))
|
||||
_save_model(MLX_FP32_MODEL_PATH, weights, config)
|
||||
|
||||
fp16_model = torch_to_mlx(torch_model, dtype=mx.float16)
|
||||
fp16_model = convert(MODEL_NAME, dtype=mx.float16)
|
||||
config = asdict(fp16_model.dims)
|
||||
weights = dict(tree_flatten(fp16_model.parameters()))
|
||||
_save_model(MLX_FP16_MODEL_PATH, weights, config)
|
||||
|
Loading…
Reference in New Issue
Block a user