mirror of
				https://github.com/ml-explore/mlx-examples.git
				synced 2025-10-31 02:48:07 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			47 lines
		
	
	
		
			1.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			47 lines
		
	
	
		
			1.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright © 2023 Apple Inc.
 | |
| 
 | |
| import json
 | |
| from pathlib import Path
 | |
| 
 | |
| import mlx.core as mx
 | |
| import mlx.nn as nn
 | |
| from huggingface_hub import snapshot_download
 | |
| from mlx.utils import tree_unflatten
 | |
| 
 | |
| from . import whisper
 | |
| 
 | |
| 
 | |
def load_model(
    path_or_hf_repo: str,
    dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper:
    """Construct a Whisper model and fill it with weights read from disk.

    Args:
        path_or_hf_repo: A filesystem path to a model directory. If the path
            does not exist, the string is passed to ``snapshot_download`` as
            ``repo_id`` and the returned location is used as the directory.
        dtype: Forwarded as the second argument to ``whisper.Whisper``.

    Returns:
        The ``whisper.Whisper`` instance after its parameters have been
        updated from the weight file and evaluated.

    The model directory is expected to contain ``config.json`` and either
    ``weights.safetensors`` or ``weights.npz``.
    """
    model_dir = Path(path_or_hf_repo)
    if not model_dir.exists():
        # Not present locally: resolve the argument through the Hugging Face hub.
        model_dir = Path(snapshot_download(repo_id=path_or_hf_repo))

    with open(str(model_dir / "config.json"), "r") as config_file:
        config = json.load(config_file)

    # These two keys are not constructor arguments of ModelDimensions, so they
    # are removed before the remaining entries are splatted into it.
    config.pop("model_type", None)
    quantization = config.pop("quantization", None)

    dims = whisper.ModelDimensions(**config)

    # Prefer the safetensors file; fall back to the npz file when it is absent.
    safetensors_file = model_dir / "weights.safetensors"
    weights_file = (
        safetensors_file
        if safetensors_file.exists()
        else model_dir / "weights.npz"
    )
    flat_weights = mx.load(str(weights_file))

    model = whisper.Whisper(dims, dtype)

    if quantization is not None:

        def has_quantized_weights(path, module):
            # Quantize only Linear/Embedding layers for which the loaded
            # weights contain a matching "<path>.scales" entry.
            return (
                isinstance(module, (nn.Linear, nn.Embedding))
                and f"{path}.scales" in flat_weights
            )

        nn.quantize(
            model,
            **quantization,
            class_predicate=has_quantized_weights,
        )

    # Turn the flat "a.b.c" -> array mapping into the nested structure that
    # model.update consumes.
    nested_weights = tree_unflatten(list(flat_weights.items()))
    model.update(nested_weights)

    # Evaluate the parameters before handing the model back to the caller.
    mx.eval(model.parameters())
    return model
 | 
