mirror of
				https://github.com/ml-explore/mlx-examples.git
				synced 2025-10-31 02:48:07 +08:00 
			
		
		
		
	Whisper updates to allow HF models (#923)
* Simplify conversion and update `convert` to accept Hugging Face models
* Use NPZ for compatibility
* Assorted fixes (including GGUF)
* Allow a user-supplied model path
This commit is contained in:
		| @@ -105,6 +105,88 @@ def available_models() -> List[str]: | ||||
|     return list(_MODELS.keys()) | ||||
|  | ||||
|  | ||||
def hf_to_pt(weights, config):
    """Translate a Hugging Face Whisper checkpoint into the OpenAI layout.

    Renames every state-dict key from the HF transformers naming scheme to
    the original OpenAI Whisper scheme and converts the HF ``config.json``
    fields into the OpenAI ``dims`` dictionary.

    Parameters
    ----------
    weights : dict
        HF state dict; mutated in place (``proj_out.weight`` is popped).
    config : dict
        Parsed HF ``config.json``.

    Returns
    -------
    (weights, config) : tuple[dict, dict]
        Remapped state dict and OpenAI-style dims dictionary.
    """
    dims = {
        "n_mels": config["num_mel_bins"],
        "n_audio_ctx": config["max_source_positions"],
        "n_audio_state": config["d_model"],
        "n_audio_head": config["encoder_attention_heads"],
        "n_audio_layer": config["encoder_layers"],
        "n_vocab": config["vocab_size"],
        "n_text_ctx": config["max_target_positions"],
        "n_text_state": config["d_model"],
        "n_text_head": config["decoder_attention_heads"],
        "n_text_layer": config["decoder_layers"],
    }

    # Ordered old -> new substring substitutions. Order matters:
    # ".self_attn" must be rewritten before ".attn_layer_norm" can match,
    # and ".encoder_attn." (projections) before ".encoder_attn_layer_norm".
    renames = (
        ("model.", ""),
        (".layers", ".blocks"),
        (".self_attn", ".attn"),
        (".attn_layer_norm", ".attn_ln"),
        (".encoder_attn.", ".cross_attn."),
        (".encoder_attn_layer_norm", ".cross_attn_ln"),
        (".final_layer_norm", ".mlp_ln"),
        (".q_proj", ".query"),
        (".k_proj", ".key"),
        (".v_proj", ".value"),
        (".out_proj", ".out"),
        (".fc1", ".mlp1"),
        (".fc2", ".mlp2"),
        ("embed_positions.weight", "positional_embedding"),
        ("decoder.embed_tokens", "decoder.token_embedding"),
        ("encoder.layer_norm", "encoder.ln_post"),
        ("decoder.layer_norm", "decoder.ln"),
    )

    def remap(key):
        for old, new in renames:
            key = key.replace(old, new)
        return key

    # token embeddings are shared with output projection
    weights.pop("proj_out.weight", None)
    remapped = {remap(key): value for key, value in weights.items()}
    return remapped, dims
|  | ||||
|  | ||||
def load_torch_weights_and_config(
    name_or_path: str,
    download_root: str = None,
):
    """Resolve a model name/path and load its PyTorch weights and config.

    Resolution order:
      1. An official OpenAI model name (``_MODELS``) — downloaded with
         ``_download`` and paired with its known alignment heads.
      2. An existing local path — used as-is.
      3. Otherwise, treated as a Hugging Face repo id and fetched via
         ``snapshot_download``.

    Parameters
    ----------
    name_or_path : str
        Official model name, local checkpoint path (``.pt`` file or HF
        snapshot directory), or HF repo id.
    download_root : str
        Where to cache official downloads; defaults to "~/.cache/whisper".

    Returns
    -------
    (weights, config, alignment_heads) : tuple
        State dict, OpenAI-style dims dict, and alignment heads (or None).

    Raises
    ------
    RuntimeError
        If the name is not official, not a local path, and cannot be
        fetched from the Hugging Face Hub.
    """
    if download_root is None:
        download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")

    # todo: accept alignment_heads of local Pytorch checkpoint
    alignment_heads = None
    if name_or_path in _MODELS:
        alignment_heads = _ALIGNMENT_HEADS[name_or_path]
        name_or_path = _download(_MODELS[name_or_path], download_root)
    elif not Path(name_or_path).exists():
        # Not an official name and not on disk: try the Hugging Face Hub.
        # BUG FIX: the previous logic raised "not found" in the *else*
        # branch, i.e. exactly when the local path DID exist, so existing
        # local checkpoints could never be loaded. Existing paths now fall
        # through, and the error is raised only when the Hub lookup fails.
        from huggingface_hub import snapshot_download

        try:
            name_or_path = snapshot_download(
                repo_id=name_or_path,
                allow_patterns=["*.json", "pytorch_model.bin", "*.txt"],
            )
        except Exception as e:
            raise RuntimeError(
                f"Model {name_or_path} is not found in {available_models()}, "
                "on Hugging Face or as a local path."
            ) from e

    if name_or_path.endswith(".pt"):
        # Original OpenAI checkpoint: weights and dims live in one file.
        checkpoint = torch.load(name_or_path, map_location="cpu")
        weights, config = checkpoint["model_state_dict"], checkpoint["dims"]
    else:
        # HF snapshot directory: separate weight file and config.json,
        # remapped to the OpenAI layout.
        name_or_path = Path(name_or_path)
        weights = torch.load(
            name_or_path / "pytorch_model.bin",
            map_location="cpu",
        )
        with open(name_or_path / "config.json", "r") as fp:
            config = json.load(fp)
        weights, config = hf_to_pt(weights, config)

    return weights, config, alignment_heads
|  | ||||
|  | ||||
def load_torch_model(
    name_or_path: str,
    download_root: str = None,
):
    """
    Load a Whisper ASR model as a torch module.

    Parameters
    ----------
    name_or_path : str
        one of the official model names listed by `whisper.available_models()` or
        a local Pytorch checkpoint which is in the original OpenAI format
    download_root: str
        path to download the model files; by default, it uses "~/.cache/whisper"
    """
    if download_root is None:
        download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")

    # Delegate name resolution / downloading / key remapping to the shared
    # loader, then build the torch module from the returned dims + weights.
    weights, config, alignment_heads = load_torch_weights_and_config(
        name_or_path, download_root
    )
    dims = torch_whisper.ModelDimensions(**config)
    model = torch_whisper.Whisper(dims)
    model.load_state_dict(weights)

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model
|  | ||||
|  | ||||
def convert(name_or_path: str, dtype: mx.Dtype = mx.float16):
    """Load a torch checkpoint by name or path and build an MLX Whisper model.

    Parameters
    ----------
    name_or_path : str
        Official model name, local checkpoint path, or HF repo id
        (forwarded to ``load_torch_weights_and_config``).
    dtype : mx.Dtype
        Target parameter dtype (default: float16).

    Returns
    -------
    Whisper
        The MLX model with converted weights loaded.
    """

    def remap(key, value):
        # The OpenAI checkpoints store the MLP as a Sequential (mlp.0 / mlp.2);
        # the MLX model names them mlp1 / mlp2.
        key = key.replace("mlp.0", "mlp1")
        key = key.replace("mlp.2", "mlp2")
        # MLX Conv1d expects (out, length, in) rather than torch's (out, in, length).
        if "conv" in key and value.ndim == 3:
            value = value.swapaxes(1, 2)
        return key, mx.array(value.detach()).astype(dtype)

    weights, config, alignment_heads = load_torch_weights_and_config(name_or_path)
    # The encoder positional embedding is computed, not learned — drop it.
    weights.pop("encoder.positional_embedding", None)
    weights = dict(remap(k, v) for k, v in weights.items())

    model_dims = ModelDimensions(**config)
    model = Whisper(model_dims, dtype)
    model.load_weights(list(weights.items()), strict=False)

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model
|  | ||||
|  | ||||
| def upload_to_hub(path: str, name: str, torch_name_or_path: str): | ||||
| @@ -292,13 +332,13 @@ if __name__ == "__main__": | ||||
|         action="store_true", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--q_group_size", | ||||
|         "--q-group-size", | ||||
|         help="Group size for quantization.", | ||||
|         type=int, | ||||
|         default=64, | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--q_bits", | ||||
|         "--q-bits", | ||||
|         help="Bits per weight for quantization.", | ||||
|         type=int, | ||||
|         default=4, | ||||
| @@ -318,7 +358,7 @@ if __name__ == "__main__": | ||||
|     dtype = getattr(mx, args.dtype) | ||||
|  | ||||
|     print("[INFO] Loading") | ||||
|     model = torch_to_mlx(load_torch_model(args.torch_name_or_path), dtype) | ||||
|     model = convert(args.torch_name_or_path, dtype) | ||||
|     config = asdict(model.dims) | ||||
|     weights = dict(tree_flatten(model.parameters())) | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun