mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
fix lazy load
This commit is contained in:
parent 617f9289b9
commit 65b792d7c0
@@ -15,6 +15,7 @@ DEFAULT_SEED = 0
 DEFAULT_MAX_TOKENS = 256
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
 
+
 def share_message(world, prompt):
     if world.size() == 1:
         return prompt
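For readers unfamiliar with the distributed chat path touched here: share_message(world, prompt) only matters when the script runs on more than one rank, since a single-rank world simply returns the prompt. Below is a minimal, hypothetical sketch (not the repository's implementation) of how such a helper could broadcast the rank-0 prompt to every other rank with mx.distributed collectives; the length-then-bytes all_sum scheme is an assumption chosen only to keep the example self-contained.

# Hypothetical sketch of a prompt-broadcast helper; the real share_message
# in chat.py may differ. Only rank 0 contributes non-zero data, so all_sum
# leaves every rank holding rank 0's prompt.
import mlx.core as mx

def share_message(world, prompt):
    if world.size() == 1:
        return prompt
    data = prompt.encode("utf-8") if world.rank() == 0 else b""
    # Broadcast the byte length (non-root ranks contribute zero).
    length = mx.distributed.all_sum(mx.array([len(data)], dtype=mx.int32), group=world)
    n = int(length.item())
    # Broadcast the bytes themselves the same way.
    buf = (
        mx.array(list(data), dtype=mx.int32)
        if world.rank() == 0
        else mx.zeros((n,), dtype=mx.int32)
    )
    buf = mx.distributed.all_sum(buf, group=world)
    return bytes(buf.tolist()).decode("utf-8")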
@@ -86,7 +87,10 @@ def main():
     )

     print(f"Node {world.rank()} of {world.size()}", flush=True)
-    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.", flush=True)
+    print(
+        f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.",
+        flush=True,
+    )
     world.barrier()
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
@@ -119,4 +123,3 @@ def main():
 
 if __name__ == "__main__":
     main()
-
@@ -749,7 +749,7 @@ def load(
     """
     model_path = get_model_path(path_or_hf_repo)

-    model, config = load_model(model_path, sequential_load, lazy)
+    model, config = load_model(model_path, lazy=lazy, sequential_load=sequential_load)
     if adapter_path is not None:
         model = load_adapters(model, adapter_path)
         model.eval()
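The hunk above is the substance of the commit: load_model was previously called with sequential_load and lazy as positional arguments, so each value bound to whichever parameter happened to occupy that slot in load_model's signature, and the lazy flag could land on the wrong parameter. A small self-contained illustration of the pitfall follows; the parameter names and ordering are assumptions for demonstration, not the exact load_model signature.

# Stand-in for load_model with an assumed parameter order.
def load_model(model_path, lazy=False, sequential_load=False):
    return f"lazy={lazy}, sequential_load={sequential_load}"

sequential_load, lazy = False, True

# Positional call, as in the old code: values bind by position,
# so the two flags are silently swapped.
print(load_model("some/path", sequential_load, lazy))
# -> lazy=False, sequential_load=True

# Keyword call, as in the fix: each value reaches the parameter it names,
# no matter how the signature is ordered or extended later.
print(load_model("some/path", lazy=lazy, sequential_load=sequential_load))
# -> lazy=True, sequential_load=False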
@@ -763,7 +763,7 @@ def load(
 def fetch_from_hub(
     model_path: Path, lazy: bool = False
 ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
-    model, config = load_model(model_path, lazy)
+    model, config = load_model(model_path, lazy=lazy)
     tokenizer = load_tokenizer(
         model_path, eos_token_ids=config.get("eos_token_id", None)
     )
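For context, the load() fixed above is the public entry point exposed by the mlx_lm package, and the lazy flag defers loading the weights into memory until they are actually needed. A typical usage sketch, with the model name taken from the DEFAULT_MODEL shown earlier in this diff:

# Usage sketch; assumes the mlx_lm package is installed and the model
# repo below is available locally or on the Hugging Face Hub.
from mlx_lm import load, generate

model, tokenizer = load(
    "mlx-community/Llama-3.2-3B-Instruct-4bit",
    lazy=True,  # defer evaluating the weights until they are used
)
print(generate(model, tokenizer, prompt="Hello", verbose=False))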