mirror of https://github.com/ml-explore/mlx-examples.git
Make the chat distributed
parent 026362e0f8
commit 617f9289b9
@@ -15,6 +15,24 @@ DEFAULT_SEED = 0
 DEFAULT_MAX_TOKENS = 256
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
 
 
+def share_message(world, prompt):
+    if world.size() == 1:
+        return prompt
+
+    if world.rank() == 0:
+        size = mx.array([len(prompt)])
+    else:
+        size = mx.array([0])
+    size = mx.distributed.all_sum(size, stream=mx.cpu).item()
+    if size == 0:
+        return []
+
+    if world.rank() == 0:
+        prompt = mx.array(prompt)
+    else:
+        prompt = mx.array([0] * size)
+    return mx.distributed.all_sum(prompt, stream=mx.cpu).tolist()
+
+
 def setup_arg_parser():
     """Set up and return the argument parser."""
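Note on the new helper: share_message emulates a broadcast with two all_sum collectives. First every rank agrees on the prompt length (rank 0 contributes len(prompt), the others contribute 0), then rank 0 contributes the token ids while the other ranks contribute zero vectors of that length, so the element-wise sum each rank receives is exactly rank 0's prompt; a length of 0 doubles as the quit signal. The standalone sketch below illustrates the same pattern outside the script; it is not part of the commit and assumes an MLX build whose distributed backend is configured, with one process per rank.

import mlx.core as mx

world = mx.distributed.init()

# Only rank 0 holds real data; every other rank contributes zeros.
data = [7, 8, 9] if world.rank() == 0 else []

# Step 1: agree on the payload length (other ranks add 0, so the sum is rank 0's length).
size = mx.distributed.all_sum(mx.array([len(data)]), stream=mx.cpu).item()

# Step 2: sum rank 0's payload with same-sized zero vectors from the other ranks.
payload = mx.array(data) if world.rank() == 0 else mx.array([0] * size)
payload = mx.distributed.all_sum(payload, stream=mx.cpu).tolist()

print(f"rank {world.rank()} received {payload}", flush=True)  # every rank prints [7, 8, 9]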
@@ -54,6 +72,7 @@ def setup_arg_parser():
 
 
 def main():
+    world = mx.distributed.init()
     parser = setup_arg_parser()
     args = parser.parse_args()
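Every process now calls mx.distributed.init() before parsing arguments. With a single process this returns a trivial group of size 1 and the script behaves exactly as before; under a launcher (for example mpirun with the MPI backend, or MLX's mlx.launch helper, depending on your setup) each process gets the rank and world size that the rest of the changes key off. A minimal sanity check of that setup, with check_dist.py as a purely illustrative file name:

# check_dist.py -- hypothetical file name, not part of this commit.
# Run one process per node, e.g. `mpirun -n 2 python check_dist.py`;
# single-process runs also work, with size() == 1 and the collectives as no-ops.
import mlx.core as mx

world = mx.distributed.init()
print(f"Node {world.rank()} of {world.size()}", flush=True)

# Every rank contributes 1, so each rank should get the world size back.
total = mx.distributed.all_sum(mx.array([1]), stream=mx.cpu).item()
assert total == world.size()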
@@ -63,16 +82,27 @@ def main():
         args.model,
         adapter_path=args.adapter_path,
         tokenizer_config={"trust_remote_code": True},
+        sequential_load=mx.distributed.init().size() > 1,
     )
 
-    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
+    print(f"Node {world.rank()} of {world.size()}", flush=True)
+    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.", flush=True)
+    world.barrier()
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
-        query = input(">> ")
-        if query == "q":
+        if world.rank() == 0:
+            query = input(">> ")
+            if query == "q":
+                prompt = []
+            else:
+                messages = [{"role": "user", "content": query}]
+                prompt = tokenizer.apply_chat_template(
+                    messages, add_generation_prompt=True
+                )
+
+        prompt = share_message(world, prompt)
+        if len(prompt) == 0:
             break
-        messages = [{"role": "user", "content": query}]
-        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
         for response in stream_generate(
             model,
             tokenizer,
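The rewritten loop keeps stdin on rank 0 only: rank 0 reads the query and applies the chat template, share_message hands the resulting token ids to every rank, and the empty list produced by typing 'q' makes all ranks break together; every rank then feeds the same tokens to stream_generate so the distributed model stays in step, with only rank 0 printing (next hunk). The model-free sketch below replays just that control flow; it assumes share_message from the first hunk is in scope, uses ord()/chr() as a stand-in for the tokenizer, and initializes prompt to [] on every rank so the first share is well defined.

import mlx.core as mx

world = mx.distributed.init()
while True:
    prompt = []
    if world.rank() == 0:
        query = input(">> ")
        if query != "q":
            prompt = [ord(c) for c in query]  # toy "tokenization"
    prompt = share_message(world, prompt)
    if len(prompt) == 0:  # empty message is the shared quit signal
        break
    if world.rank() == 0:  # only rank 0 talks to the terminal
        print("echo:", "".join(chr(t) for t in prompt), flush=True)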
@@ -81,9 +111,12 @@ def main():
             sampler=make_sampler(args.temp, args.top_p),
             prompt_cache=prompt_cache,
         ):
-            print(response.text, flush=True, end="")
-        print()
+            if world.rank() == 0:
+                print(response.text, flush=True, end="")
+
+        if world.rank() == 0:
+            print()
 
 
 if __name__ == "__main__":
     main()