Fix lazy load

This commit is contained in:
Awni Hannun 2025-01-14 13:14:48 -08:00
parent 617f9289b9
commit 65b792d7c0
2 changed files with 8 additions and 5 deletions

View File

@ -15,6 +15,7 @@ DEFAULT_SEED = 0
 DEFAULT_MAX_TOKENS = 256
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"

 def share_message(world, prompt):
     if world.size() == 1:
         return prompt
@ -86,7 +87,10 @@ def main():
     )
     print(f"Node {world.rank()} of {world.size()}", flush=True)
-    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.", flush=True)
+    print(
+        f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.",
+        flush=True,
+    )
     world.barrier()
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
@ -119,4 +123,3 @@ def main():
 if __name__ == "__main__":
     main()

View File

@ -749,7 +749,7 @@ def load(
""" """
model_path = get_model_path(path_or_hf_repo) model_path = get_model_path(path_or_hf_repo)
model, config = load_model(model_path, sequential_load, lazy) model, config = load_model(model_path, lazy=lazy, sequential_load=sequential_load)
if adapter_path is not None: if adapter_path is not None:
model = load_adapters(model, adapter_path) model = load_adapters(model, adapter_path)
model.eval() model.eval()
@ -763,7 +763,7 @@ def load(
 def fetch_from_hub(
     model_path: Path, lazy: bool = False
 ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
-    model, config = load_model(model_path, lazy)
+    model, config = load_model(model_path, lazy=lazy)
     tokenizer = load_tokenizer(
         model_path, eos_token_ids=config.get("eos_token_id", None)
     )