diff --git a/README.md b/README.md
index 2ca11d4b..00e57803 100644
--- a/README.md
+++ b/README.md
@@ -32,6 +32,7 @@ Some more useful examples are listed below.
 
 - Joint text and image embeddings with [CLIP](clip).
 - Text generation from image and text inputs with [LLaVA](llava).
+- Image segmentation with [Segment Anything (SAM)](segment_anything).
 
 ### Other Models
 
diff --git a/llms/mlx_lm/gguf.py b/llms/mlx_lm/gguf.py
index 1f858d70..5d524580 100644
--- a/llms/mlx_lm/gguf.py
+++ b/llms/mlx_lm/gguf.py
@@ -59,7 +59,7 @@ class HfVocab:
         for token_id in range(self.vocab_size_base):
             if token_id in self.added_tokens_ids:
                 continue
-            token_text = reverse_vocab[token_id].encode("utf-8")
+            token_text = reverse_vocab[token_id]
             yield token_text, self.get_token_score(token_id), self.get_token_type(
                 token_id, token_text, self.special_ids
             )
@@ -67,7 +67,7 @@ class HfVocab:
     def get_token_type(
         self, token_id: int, token_text: bytes, special_ids: Set[int]
     ) -> TokenType:
-        if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text):
+        if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text.encode("utf-8")):
             return TokenType.BYTE
         return TokenType.CONTROL if token_id in special_ids else TokenType.NORMAL
 
@@ -84,7 +84,7 @@ class HfVocab:
             else:
                 toktype = TokenType.USER_DEFINED
                 score = -1000.0
-            yield text.encode("utf-8"), score, toktype
+            yield text, score, toktype
 
     def has_newline_token(self):
         return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab
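Note on the `gguf.py` hunks above: vocabulary tokens now stay as `str`, and encoding to `bytes` happens only where the byte-token regex needs it. A minimal standalone sketch of that check, using a hypothetical `is_byte_token` helper that is not part of the patch:

```python
import re

# The pattern is a bytes regex, so the str token is encoded just for the
# match, mirroring HfVocab.get_token_type in the hunk above.
BYTE_TOKEN = re.compile(rb"<0x[0-9A-Fa-f]{2}>")


def is_byte_token(token_text: str) -> bool:
    return BYTE_TOKEN.fullmatch(token_text.encode("utf-8")) is not None


assert is_byte_token("<0x0A>")  # newline byte token
assert not is_byte_token("hello")  # ordinary subword
```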
diff --git a/whisper/convert.py b/whisper/convert.py
index 37825d6c..85ce5fba 100644
--- a/whisper/convert.py
+++ b/whisper/convert.py
@@ -105,6 +105,88 @@ def available_models() -> List[str]:
     return list(_MODELS.keys())
 
 
+def hf_to_pt(weights, config):
+    config = {
+        "n_mels": config["num_mel_bins"],
+        "n_audio_ctx": config["max_source_positions"],
+        "n_audio_state": config["d_model"],
+        "n_audio_head": config["encoder_attention_heads"],
+        "n_audio_layer": config["encoder_layers"],
+        "n_vocab": config["vocab_size"],
+        "n_text_ctx": config["max_target_positions"],
+        "n_text_state": config["d_model"],
+        "n_text_head": config["decoder_attention_heads"],
+        "n_text_layer": config["decoder_layers"],
+    }
+
+    def remap(k):
+        k = k.replace("model.", "")
+        k = k.replace(".layers", ".blocks")
+        k = k.replace(".self_attn", ".attn")
+        k = k.replace(".attn_layer_norm", ".attn_ln")
+        k = k.replace(".encoder_attn.", ".cross_attn.")
+        k = k.replace(".encoder_attn_layer_norm", ".cross_attn_ln")
+        k = k.replace(".final_layer_norm", ".mlp_ln")
+        k = k.replace(".q_proj", ".query")
+        k = k.replace(".k_proj", ".key")
+        k = k.replace(".v_proj", ".value")
+        k = k.replace(".out_proj", ".out")
+        k = k.replace(".fc1", ".mlp1")
+        k = k.replace(".fc2", ".mlp2")
+        k = k.replace("embed_positions.weight", "positional_embedding")
+        k = k.replace("decoder.embed_tokens", "decoder.token_embedding")
+        k = k.replace("encoder.layer_norm", "encoder.ln_post")
+        k = k.replace("decoder.layer_norm", "decoder.ln")
+        return k
+
+    # token embeddings are shared with output projection
+    weights.pop("proj_out.weight", None)
+    weights = {remap(k): v for k, v in weights.items()}
+    return weights, config
+
+
+def load_torch_weights_and_config(
+    name_or_path: str,
+    download_root: str = None,
+):
+    if download_root is None:
+        download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")
+
+    # todo: accept alignment_heads of local PyTorch checkpoint
+    alignment_heads = None
+    if name_or_path in _MODELS:
+        alignment_heads = _ALIGNMENT_HEADS[name_or_path]
+        name_or_path = _download(_MODELS[name_or_path], download_root)
+    elif not Path(name_or_path).exists():
+        # Try downloading from HF
+        try:
+            from huggingface_hub import snapshot_download
+            name_or_path = snapshot_download(
+                repo_id=name_or_path,
+                allow_patterns=["*.json", "pytorch_model.bin", "*.txt"],
+            )
+        except Exception as e:
+            raise RuntimeError(
+                f"Model {name_or_path} is not found in {available_models()},"
+                " on Hugging Face, or as a local path."
+            ) from e
+
+    if name_or_path.endswith(".pt"):
+        checkpoint = torch.load(name_or_path, map_location="cpu")
+        weights, config = checkpoint["model_state_dict"], checkpoint["dims"]
+    else:
+        name_or_path = Path(name_or_path)
+        weights = torch.load(
+            name_or_path / "pytorch_model.bin",
+            map_location="cpu",
+        )
+        with open(name_or_path / "config.json", "r") as fp:
+            config = json.load(fp)
+        weights, config = hf_to_pt(weights, config)
+
+    return weights, config, alignment_heads
+
+
 def load_torch_model(
     name_or_path: str,
     download_root: str = None,
@@ -115,7 +197,8 @@ def load_torch_model(
     Parameters
     ----------
     name_or_path : str
-        one of the official model names listed by `whisper.available_models()` or a local Pytorch checkpoint which is in the original OpenAI format
+        one of the official model names listed by `whisper.available_models()` or
+        a local PyTorch checkpoint which is in the original OpenAI format
     download_root: str
         path to download the model files; by default, it uses "~/.cache/whisper"
 
@@ -128,22 +211,12 @@ def load_torch_model(
     if download_root is None:
         download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")
 
-    # todo: accept alignment_heads of local Pytorch checkpoint
-    alignment_heads = None
-    if name_or_path in _MODELS:
-        alignment_heads = _ALIGNMENT_HEADS[name_or_path]
-        name_or_path = _download(_MODELS[name_or_path], download_root)
-    elif not Path(name_or_path).is_file():
-        raise RuntimeError(
-            f"Model {name_or_path} is neither found in {available_models()} nor as a local path"
-        )
-
-    with open(name_or_path, "rb") as fp:
-        checkpoint = torch.load(fp)
-
-    dims = torch_whisper.ModelDimensions(**checkpoint["dims"])
+    weights, config, alignment_heads = load_torch_weights_and_config(
+        name_or_path, download_root
+    )
+    dims = torch_whisper.ModelDimensions(**config)
     model = torch_whisper.Whisper(dims)
-    model.load_state_dict(checkpoint["model_state_dict"])
+    model.load_state_dict(weights)
 
     if alignment_heads is not None:
         model.set_alignment_heads(alignment_heads)
@@ -151,59 +224,26 @@ def load_torch_model(
     return model
 
 
-def convert(model, rules=None):
-    params = {}
-    if rules is not None and type(model) in rules:
-        out = rules[type(model)](model, rules)
-        return out
-    if isinstance(model, torch.Tensor):
-        return mx.array(model.detach().numpy())
-    if isinstance(model, torch.nn.ModuleList):
-        return [convert(n, rules) for n in model.children()]
-    if isinstance(model, torch.nn.Conv1d):
-        return {
-            "weight": convert(model.weight).transpose(0, 2, 1),
-            "bias": convert(model.bias),
-        }
-    for k, n in model.named_children():
-        if k in rules:
-            params.update(rules[k](n, rules))
-        else:
-            params[k] = convert(n, rules)
-    for k, p in model.named_parameters(recurse=False):
-        params[k] = convert(p)
-    return params
+def convert(name_or_path: str, dtype: mx.Dtype = mx.float16):
+    def remap(key, value):
+        key = key.replace("mlp.0", "mlp1")
+        key = key.replace("mlp.2", "mlp2")
+        if "conv" in key and value.ndim == 3:
+            value = value.swapaxes(1, 2)
+        return key, mx.array(value.detach()).astype(dtype)
 
+    weights, config, alignment_heads = load_torch_weights_and_config(name_or_path)
+    weights.pop("encoder.positional_embedding", None)
+    weights = dict(remap(k, v) for k, v in weights.items())
 
-def torch_to_mlx(
-    torch_model: torch_whisper.Whisper,
-    dtype: mx.Dtype = mx.float16,
-) -> Whisper:
-    def convert_rblock(model, rules):
-        children = dict(model.named_children())
-        mlp = list(children.pop("mlp").children())
-        params = {
-            "mlp1": convert(mlp[0], rules),
-            "mlp2": convert(mlp[-1], rules),
-        }
-        for k, n in children.items():
-            params[k] = convert(n, rules)
-        return params
+    model_dims = ModelDimensions(**config)
+    model = Whisper(model_dims, dtype)
+    model.load_weights(list(weights.items()), strict=False)
 
-    rules = {
-        torch_whisper.ResidualAttentionBlock: convert_rblock,
-    }
+    if alignment_heads is not None:
+        model.set_alignment_heads(alignment_heads)
 
-    params = convert(torch_model, rules)
-
-    mlx_model = Whisper(torch_model.dims, dtype)
-    params = tree_map(lambda p: p.astype(dtype), params)
-    mlx_model.update(params)
-
-    if (alignment_heads := getattr(torch_model, "alignment_heads", None)) is not None:
-        mlx_model.set_alignment_heads(alignment_heads.indices().T.numpy())
-
-    return mlx_model
+    return model
 
 
 def upload_to_hub(path: str, name: str, torch_name_or_path: str):
@@ -292,13 +332,13 @@ if __name__ == "__main__":
         action="store_true",
     )
     parser.add_argument(
-        "--q_group_size",
+        "--q-group-size",
         help="Group size for quantization.",
         type=int,
         default=64,
     )
     parser.add_argument(
-        "--q_bits",
+        "--q-bits",
         help="Bits per weight for quantization.",
         type=int,
         default=4,
@@ -318,7 +358,7 @@ if __name__ == "__main__":
     dtype = getattr(mx, args.dtype)
 
     print("[INFO] Loading")
-    model = torch_to_mlx(load_torch_model(args.torch_name_or_path), dtype)
+    model = convert(args.torch_name_or_path, dtype)
     config = asdict(model.dims)
     weights = dict(tree_flatten(model.parameters()))
 
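For context, a small usage sketch of the new `convert()` entry point, mirroring how `whisper/test.py` calls it below; the snippet itself is not part of the patch and assumes it is run from the `whisper/` directory:

```python
import mlx.core as mx
from convert import convert

# "tiny" is resolved by load_torch_weights_and_config: official names are
# downloaded from OpenAI, unknown names are tried on Hugging Face, and
# existing local paths are loaded directly.
model = convert("tiny", dtype=mx.float16)
print(model.dims)
```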
diff --git a/whisper/mlx_whisper/version.py b/whisper/mlx_whisper/version.py
index 87ee07a7..ae3cfb71 100644
--- a/whisper/mlx_whisper/version.py
+++ b/whisper/mlx_whisper/version.py
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.1.0"
+__version__ = "0.2.0"
diff --git a/whisper/test.py b/whisper/test.py
index ce559251..f0acb3cd 100644
--- a/whisper/test.py
+++ b/whisper/test.py
@@ -13,7 +13,7 @@ import mlx_whisper.decoding as decoding
 import mlx_whisper.load_models as load_models
 import numpy as np
 import torch
-from convert import load_torch_model, quantize, torch_to_mlx
+from convert import convert, load_torch_model, quantize
 from mlx.utils import tree_flatten
 
 MODEL_NAME = "tiny"
@@ -41,12 +41,12 @@ def _save_model(save_dir, weights, config):
 
 def load_torch_and_mlx():
     torch_model = load_torch_model(MODEL_NAME)
-    fp32_model = torch_to_mlx(torch_model, dtype=mx.float32)
+    fp32_model = convert(MODEL_NAME, dtype=mx.float32)
     config = asdict(fp32_model.dims)
     weights = dict(tree_flatten(fp32_model.parameters()))
     _save_model(MLX_FP32_MODEL_PATH, weights, config)
 
-    fp16_model = torch_to_mlx(torch_model, dtype=mx.float16)
+    fp16_model = convert(MODEL_NAME, dtype=mx.float16)
     config = asdict(fp16_model.dims)
     weights = dict(tree_flatten(fp16_model.parameters()))
     _save_model(MLX_FP16_MODEL_PATH, weights, config)
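One note on the `--q-group-size` / `--q-bits` rename above: argparse turns dashes in option names into underscores for the attribute names, so the script still reads `args.q_group_size` and `args.q_bits` and the quantization code path is unchanged. A minimal sketch (the example arguments are hypothetical):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--q-group-size", type=int, default=64)
parser.add_argument("--q-bits", type=int, default=4)

# Dashes in the flag names map to underscores in the parsed namespace.
args = parser.parse_args(["--q-bits", "8", "--q-group-size", "32"])
print(args.q_bits, args.q_group_size)  # -> 8 32
```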