Whisper updates to allow HF models (#923)

* simplify conversion and update convert for HF models

* use npz for compat

* fixes

* fixes

* fix gguf

* allow user supplied path
Awni Hannun 2024-08-09 11:11:58 -07:00 committed by GitHub
parent df744c98e6
commit 33905447f9
5 changed files with 116 additions and 75 deletions
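
In practical terms, the conversion entry point can now be pointed at an official model name, a local checkpoint, or a Hugging Face repo id. A hypothetical invocation of the updated script (the flag spellings are inferred from the argparse attributes visible in the diff below and may differ):

python convert.py --torch-name-or-path openai/whisper-tiny --dtype float16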


@@ -32,6 +32,7 @@ Some more useful examples are listed below.
- Joint text and image embeddings with [CLIP](clip).
- Text generation from image and text inputs with [LLaVA](llava).
- Image segmentation with [Segment Anything (SAM)](segment_anything).
### Other Models


@@ -59,7 +59,7 @@ class HfVocab:
for token_id in range(self.vocab_size_base):
if token_id in self.added_tokens_ids:
continue
token_text = reverse_vocab[token_id].encode("utf-8")
token_text = reverse_vocab[token_id]
yield token_text, self.get_token_score(token_id), self.get_token_type(
token_id, token_text, self.special_ids
)
@@ -67,7 +67,7 @@ class HfVocab:
def get_token_type(
self, token_id: int, token_text: bytes, special_ids: Set[int]
) -> TokenType:
if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text):
if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text.encode("utf-8")):
return TokenType.BYTE
return TokenType.CONTROL if token_id in special_ids else TokenType.NORMAL
@@ -84,7 +84,7 @@ class HfVocab:
else:
toktype = TokenType.USER_DEFINED
score = -1000.0
yield text.encode("utf-8"), score, toktype
yield text, score, toktype
def has_newline_token(self):
return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab


@@ -105,6 +105,88 @@ def available_models() -> List[str]:
return list(_MODELS.keys())
def hf_to_pt(weights, config):
config = {
"n_mels": config["num_mel_bins"],
"n_audio_ctx": config["max_source_positions"],
"n_audio_state": config["d_model"],
"n_audio_head": config["encoder_attention_heads"],
"n_audio_layer": config["encoder_layers"],
"n_vocab": config["vocab_size"],
"n_text_ctx": config["max_target_positions"],
"n_text_state": config["d_model"],
"n_text_head": config["decoder_attention_heads"],
"n_text_layer": config["decoder_layers"],
}
def remap(k):
k = k.replace("model.", "")
k = k.replace(".layers", ".blocks")
k = k.replace(".self_attn", ".attn")
k = k.replace(".attn_layer_norm", ".attn_ln")
k = k.replace(".encoder_attn.", ".cross_attn.")
k = k.replace(".encoder_attn_layer_norm", ".cross_attn_ln")
k = k.replace(".final_layer_norm", ".mlp_ln")
k = k.replace(".q_proj", ".query")
k = k.replace(".k_proj", ".key")
k = k.replace(".v_proj", ".value")
k = k.replace(".out_proj", ".out")
k = k.replace(".fc1", ".mlp1")
k = k.replace(".fc2", ".mlp2")
k = k.replace("embed_positions.weight", "positional_embedding")
k = k.replace("decoder.embed_tokens", "decoder.token_embedding")
k = k.replace("encoder.layer_norm", "encoder.ln_post")
k = k.replace("decoder.layer_norm", "decoder.ln")
return k
# token embeddings are shared with output projection
weights.pop("proj_out.weight", None)
weights = {remap(k): v for k, v in weights.items()}
return weights, config
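
To make the remapping concrete, a few illustrative translations (constructed here, not taken from the commit) of Hugging Face parameter names into the OpenAI-style names the MLX model expects; the config keys are renamed the same way (num_mel_bins -> n_mels, d_model -> n_audio_state/n_text_state, and so on):

# "model.decoder.layers.0.self_attn.q_proj.weight" -> "decoder.blocks.0.attn.query.weight"
# "model.encoder.layers.0.fc1.bias"                -> "encoder.blocks.0.mlp1.bias"
# "model.decoder.embed_tokens.weight"              -> "decoder.token_embedding.weight"
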
def load_torch_weights_and_config(
name_or_path: str,
download_root: str = None,
):
if download_root is None:
download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")
# todo: accept alignment_heads of local Pytorch checkpoint
alignment_heads = None
if name_or_path in _MODELS:
alignment_heads = _ALIGNMENT_HEADS[name_or_path]
name_or_path = _download(_MODELS[name_or_path], download_root)
elif not Path(name_or_path).exists():
# Try downloading from HF
from huggingface_hub import snapshot_download
name_or_path = snapshot_download(
repo_id=name_or_path,
allow_patterns=["*.json", "pytorch_model.bin", "*.txt"],
)
else:
raise RuntimeError(
f"Model {name_or_path} is not found in {available_models()},"
"on Hugging Face or as a local path."
)
if name_or_path.endswith(".pt"):
checkpoint = torch.load(name_or_path, map_location="cpu")
weights, config = checkpoint["model_state_dict"], checkpoint["dims"]
else:
name_or_path = Path(name_or_path)
weights = torch.load(
name_or_path / "pytorch_model.bin",
map_location="cpu",
)
with open(name_or_path / "config.json", "r") as fp:
config = json.load(fp)
weights, config = hf_to_pt(weights, config)
return weights, config, alignment_heads
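
Depending on how name_or_path resolves, the function reads either an original OpenAI .pt checkpoint (model_state_dict plus dims) or a Hugging Face-style directory (pytorch_model.bin plus config.json, remapped through hf_to_pt). Illustrative calls, with the repo id chosen only as an example:

weights, config, alignment_heads = load_torch_weights_and_config("tiny")                 # official name, .pt branch
weights, config, alignment_heads = load_torch_weights_and_config("openai/whisper-tiny")  # HF repo id, pytorch_model.bin branch
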
def load_torch_model(
name_or_path: str,
download_root: str = None,
@@ -115,7 +197,8 @@ def load_torch_model(
Parameters
----------
name_or_path : str
one of the official model names listed by `whisper.available_models()` or a local Pytorch checkpoint which is in the original OpenAI format
one of the official model names listed by `whisper.available_models()` or
a local Pytorch checkpoint which is in the original OpenAI format
download_root: str
path to download the model files; by default, it uses "~/.cache/whisper"
@@ -128,22 +211,12 @@ def load_torch_model(
if download_root is None:
download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")
# todo: accept alignment_heads of local Pytorch checkpoint
alignment_heads = None
if name_or_path in _MODELS:
alignment_heads = _ALIGNMENT_HEADS[name_or_path]
name_or_path = _download(_MODELS[name_or_path], download_root)
elif not Path(name_or_path).is_file():
raise RuntimeError(
f"Model {name_or_path} is neither found in {available_models()} nor as a local path"
)
with open(name_or_path, "rb") as fp:
checkpoint = torch.load(fp)
dims = torch_whisper.ModelDimensions(**checkpoint["dims"])
weights, config, alignment_heads = load_torch_weights_and_config(
name_or_path, download_root
)
dims = torch_whisper.ModelDimensions(**config)
model = torch_whisper.Whisper(dims)
model.load_state_dict(checkpoint["model_state_dict"])
model.load_state_dict(weights)
if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)
@@ -151,59 +224,26 @@ def load_torch_model(
return model
def convert(model, rules=None):
params = {}
if rules is not None and type(model) in rules:
out = rules[type(model)](model, rules)
return out
if isinstance(model, torch.Tensor):
return mx.array(model.detach().numpy())
if isinstance(model, torch.nn.ModuleList):
return [convert(n, rules) for n in model.children()]
if isinstance(model, torch.nn.Conv1d):
return {
"weight": convert(model.weight).transpose(0, 2, 1),
"bias": convert(model.bias),
}
for k, n in model.named_children():
if k in rules:
params.update(rules[k](n, rules))
else:
params[k] = convert(n, rules)
for k, p in model.named_parameters(recurse=False):
params[k] = convert(p)
return params
def convert(name_or_path: str, dtype: mx.Dtype = mx.float16):
def remap(key, value):
key = key.replace("mlp.0", "mlp1")
key = key.replace("mlp.2", "mlp2")
if "conv" in key and value.ndim == 3:
value = value.swapaxes(1, 2)
return key, mx.array(value.detach()).astype(dtype)
weights, config, alignment_heads = load_torch_weights_and_config(name_or_path)
weights.pop("encoder.positional_embedding", None)
weights = dict(remap(k, v) for k, v in weights.items())
def torch_to_mlx(
torch_model: torch_whisper.Whisper,
dtype: mx.Dtype = mx.float16,
) -> Whisper:
def convert_rblock(model, rules):
children = dict(model.named_children())
mlp = list(children.pop("mlp").children())
params = {
"mlp1": convert(mlp[0], rules),
"mlp2": convert(mlp[-1], rules),
}
for k, n in children.items():
params[k] = convert(n, rules)
return params
model_dims = ModelDimensions(**config)
model = Whisper(model_dims, dtype)
model.load_weights(list(weights.items()), strict=False)
rules = {
torch_whisper.ResidualAttentionBlock: convert_rblock,
}
if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)
params = convert(torch_model, rules)
mlx_model = Whisper(torch_model.dims, dtype)
params = tree_map(lambda p: p.astype(dtype), params)
mlx_model.update(params)
if (alignment_heads := getattr(torch_model, "alignment_heads", None)) is not None:
mlx_model.set_alignment_heads(alignment_heads.indices().T.numpy())
return mlx_model
return model
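
The reworked convert thus returns a ready MLX Whisper model directly from a name or path, replacing the old load_torch_model + torch_to_mlx pair; note that conv weights are transposed with swapaxes(1, 2) for MLX's channels-last Conv1d layout. A sketch of the Python-level usage, mirroring the __main__ block and the tests further down:

model = convert("tiny", dtype=mx.float16)
config = asdict(model.dims)
weights = dict(tree_flatten(model.parameters()))
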
def upload_to_hub(path: str, name: str, torch_name_or_path: str):
@@ -292,13 +332,13 @@ if __name__ == "__main__":
action="store_true",
)
parser.add_argument(
"--q_group_size",
"--q-group-size",
help="Group size for quantization.",
type=int,
default=64,
)
parser.add_argument(
"--q_bits",
"--q-bits",
help="Bits per weight for quantization.",
type=int,
default=4,
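
Both quantization options now use dashed flags. A hypothetical quantized conversion; the store_true toggle defined just above this hunk is assumed to be spelled --quantize, which is not shown in the diff:

python convert.py --torch-name-or-path tiny --quantize --q-bits 4 --q-group-size 64
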
@@ -318,7 +358,7 @@ if __name__ == "__main__":
dtype = getattr(mx, args.dtype)
print("[INFO] Loading")
model = torch_to_mlx(load_torch_model(args.torch_name_or_path), dtype)
model = convert(args.torch_name_or_path, dtype)
config = asdict(model.dims)
weights = dict(tree_flatten(model.parameters()))


@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
__version__ = "0.1.0"
__version__ = "0.2.0"


@@ -13,7 +13,7 @@ import mlx_whisper.decoding as decoding
import mlx_whisper.load_models as load_models
import numpy as np
import torch
from convert import load_torch_model, quantize, torch_to_mlx
from convert import convert, load_torch_model, quantize
from mlx.utils import tree_flatten
MODEL_NAME = "tiny"
@@ -41,12 +41,12 @@ def _save_model(save_dir, weights, config):
def load_torch_and_mlx():
torch_model = load_torch_model(MODEL_NAME)
fp32_model = torch_to_mlx(torch_model, dtype=mx.float32)
fp32_model = convert(MODEL_NAME, dtype=mx.float32)
config = asdict(fp32_model.dims)
weights = dict(tree_flatten(fp32_model.parameters()))
_save_model(MLX_FP32_MODEL_PATH, weights, config)
fp16_model = torch_to_mlx(torch_model, dtype=mx.float16)
fp16_model = convert(MODEL_NAME, dtype=mx.float16)
config = asdict(fp16_model.dims)
weights = dict(tree_flatten(fp16_model.parameters()))
_save_model(MLX_FP16_MODEL_PATH, weights, config)
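
The "use npz for compat" bullet from the commit message refers to the saved weight format: _save_model (whose body is not shown in this diff) is expected to write a weights.npz. A minimal sketch of that save step under that assumption, with the file name and helper purely hypothetical:

import mlx.core as mx
from pathlib import Path

def save_weights(save_dir: str, weights: dict) -> None:
    # weights: flat dict of parameter name -> mx.array, as produced by tree_flatten
    mx.savez(str(Path(save_dir) / "weights.npz"), **weights)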