Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-10-31 10:58:07 +08:00)

Commit: fix lazy load
Author: Awni Hannun
@@ -15,6 +15,7 @@ DEFAULT_SEED = 0
 DEFAULT_MAX_TOKENS = 256
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
 
+
 def share_message(world, prompt):
     if world.size() == 1:
         return prompt
@@ -30,7 +31,7 @@ def share_message(world, prompt):
     if world.rank() == 0:
         prompt = mx.array(prompt)
     else:
-        prompt = mx.array([0]*len(prompt))
+        prompt = mx.array([0] * len(prompt))
     return mx.distributed.all_sum(size, stream=mx.cpu).tolist()
 
 
@@ -86,7 +87,10 @@ def main():
     )
 
     print(f"Node {world.rank()} of {world.size()}", flush=True)
-    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.", flush=True)
+    print(
+        f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.",
+        flush=True,
+    )
     world.barrier()
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
@@ -119,4 +123,3 @@ def main():
 
 if __name__ == "__main__":
     main()
-
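For context, the first file is a distributed chat example. `share_message` shares the rank-0 prompt with every rank: the other ranks contribute zero-filled arrays of the same length, so `mx.distributed.all_sum` effectively acts as a broadcast. A minimal sketch of that pattern, assuming an initialized MLX distributed backend (the helper name `broadcast_from_rank0` is illustrative, not code from the example):

import mlx.core as mx

def broadcast_from_rank0(world, values):
    # Every rank calls all_sum with an array of the same length.
    # Rank 0 contributes the real values, the others contribute zeros,
    # so the element-wise sum equals rank 0's values on every rank.
    if world.rank() == 0:
        arr = mx.array(values)
    else:
        arr = mx.array([0] * len(values))
    return mx.distributed.all_sum(arr, stream=mx.cpu).tolist()

world = mx.distributed.init()
# Non-zero ranks only need a placeholder list of the right length.
tokens = broadcast_from_rank0(world, [7, 8, 9] if world.rank() == 0 else [0, 0, 0])
print(world.rank(), tokens)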
@@ -749,7 +749,7 @@ def load(
     """
     model_path = get_model_path(path_or_hf_repo)
 
-    model, config = load_model(model_path, sequential_load, lazy)
+    model, config = load_model(model_path, lazy=lazy, sequential_load=sequential_load)
     if adapter_path is not None:
         model = load_adapters(model, adapter_path)
         model.eval()
@@ -763,7 +763,7 @@ def load(
 def fetch_from_hub(
     model_path: Path, lazy: bool = False
 ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
-    model, config = load_model(model_path, lazy)
+    model, config = load_model(model_path, lazy=lazy)
     tokenizer = load_tokenizer(
         model_path, eos_token_ids=config.get("eos_token_id", None)
     )
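The second file is the loader in utils.py, and this hunk carries the actual fix: `load_model` is now called with keyword arguments so the `lazy` flag can no longer land in the wrong positional slot, which is presumably what broke lazy loading. A minimal sketch of that failure mode, using a hypothetical stand-in signature (the real `load_model` parameters may differ):

# Hypothetical stand-in for the loader; the real signature may differ.
def load_model(model_path, lazy=False, sequential_load=False):
    return {"lazy": lazy, "sequential_load": sequential_load}

def load(model_path, lazy=False, sequential_load=False):
    # Positional call: flags are bound by position, so if the caller's
    # argument order does not match the callee's parameter order, the
    # two booleans are silently swapped.
    broken = load_model(model_path, sequential_load, lazy)
    # Keyword call (what the commit switches to): each flag is bound by
    # name, so a parameter reorder cannot misroute `lazy`.
    fixed = load_model(model_path, lazy=lazy, sequential_load=sequential_load)
    return broken, fixed

broken, fixed = load("mlx-community/Llama-3.2-3B-Instruct-4bit", lazy=True)
print(broken)  # {'lazy': False, 'sequential_load': True}  -- lazy flag lost
print(fixed)   # {'lazy': True, 'sequential_load': False}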