| 
									
										
										
										
											2023-11-30 11:08:53 -08:00
										 |  |  | # Copyright © 2023 Apple Inc. | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  | import json | 
					
						
							|  |  |  | from pathlib import Path | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | import mlx.core as mx | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  | import mlx.nn as nn | 
					
						
							| 
									
										
										
										
											2024-01-08 19:50:00 +05:30
										 |  |  | from huggingface_hub import snapshot_download | 
					
						
							| 
									
										
										
										
											2024-01-12 13:45:30 -08:00
										 |  |  | from mlx.utils import tree_unflatten | 
					
						
							| 
									
										
										
										
											2024-01-08 19:50:00 +05:30
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  | from . import whisper | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  | def load_model( | 
					
						
							| 
									
										
										
										
											2024-01-08 19:50:00 +05:30
										 |  |  |     path_or_hf_repo: str, | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  |     dtype: mx.Dtype = mx.float32, | 
					
						
							|  |  |  | ) -> whisper.Whisper: | 
					
						
							| 
									
										
										
										
											2024-01-08 19:50:00 +05:30
										 |  |  |     model_path = Path(path_or_hf_repo) | 
					
						
							|  |  |  |     if not model_path.exists(): | 
					
						
							| 
									
										
										
										
											2024-01-12 13:45:30 -08:00
										 |  |  |         model_path = Path(snapshot_download(repo_id=path_or_hf_repo)) | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  |     with open(str(model_path / "config.json"), "r") as f: | 
					
						
							|  |  |  |         config = json.loads(f.read()) | 
					
						
							|  |  |  |         config.pop("model_type", None) | 
					
						
							|  |  |  |         quantization = config.pop("quantization", None) | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  |     model_args = whisper.ModelDimensions(**config) | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-01 10:52:28 -07:00
										 |  |  |     wf = model_path / "weights.safetensors" | 
					
						
							|  |  |  |     if not wf.exists(): | 
					
						
							|  |  |  |         wf = model_path / "weights.npz" | 
					
						
							|  |  |  |     weights = mx.load(str(wf)) | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  |     model = whisper.Whisper(model_args, dtype) | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  |     if quantization is not None: | 
					
						
							| 
									
										
										
										
											2024-04-19 20:07:11 -07:00
										 |  |  |         class_predicate = ( | 
					
						
							|  |  |  |             lambda p, m: isinstance(m, (nn.Linear, nn.Embedding)) | 
					
						
							|  |  |  |             and f"{p}.scales" in weights | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         nn.quantize(model, **quantization, class_predicate=class_predicate) | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-19 20:07:11 -07:00
										 |  |  |     weights = tree_unflatten(list(weights.items())) | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  |     model.update(weights) | 
					
						
							|  |  |  |     mx.eval(model.parameters()) | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |     return model |