From 617f9289b91c0c7dceea48265554cb747c08a0f0 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 5 Nov 2024 13:09:34 -0800
Subject: [PATCH] Make the chat distributed

---
 llms/mlx_lm/chat.py | 47 ++++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 40 insertions(+), 7 deletions(-)

diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py
index e52ad10d..380db082 100644
--- a/llms/mlx_lm/chat.py
+++ b/llms/mlx_lm/chat.py
@@ -15,6 +15,24 @@ DEFAULT_SEED = 0
 DEFAULT_MAX_TOKENS = 256
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
 
+
+def share_message(world, prompt):
+    if world.size() == 1:
+        return prompt
+
+    if world.rank() == 0:
+        size = mx.array([len(prompt)])
+    else:
+        size = mx.array([0])
+    size = mx.distributed.all_sum(size, stream=mx.cpu).item()
+    if size == 0:
+        return []
+
+    if world.rank() == 0:
+        prompt = mx.array(prompt)
+    else:
+        prompt = mx.array([0] * size)
+    return mx.distributed.all_sum(prompt, stream=mx.cpu).tolist()
+
 
 def setup_arg_parser():
     """Set up and return the argument parser."""
@@ -54,6 +72,7 @@ def setup_arg_parser():
     return parser
 
 
 def main():
+    world = mx.distributed.init()
     parser = setup_arg_parser()
     args = parser.parse_args()
@@ -63,16 +82,27 @@ def main():
         args.model,
         adapter_path=args.adapter_path,
         tokenizer_config={"trust_remote_code": True},
+        sequential_load=mx.distributed.init().size() > 1,
     )
 
-    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
+    print(f"Node {world.rank()} of {world.size()}", flush=True)
+    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.", flush=True)
+    world.barrier()
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
-        query = input(">> ")
-        if query == "q":
+        if world.rank() == 0:
+            query = input(">> ")
+            if query == "q":
+                prompt = []
+            else:
+                messages = [{"role": "user", "content": query}]
+                prompt = tokenizer.apply_chat_template(
+                    messages, add_generation_prompt=True
+                )
+
+        prompt = share_message(world, prompt if world.rank() == 0 else [])
+        if len(prompt) == 0:
             break
-        messages = [{"role": "user", "content": query}]
-        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
         for response in stream_generate(
             model,
             tokenizer,
@@ -81,9 +111,12 @@ def main():
             sampler=make_sampler(args.temp, args.top_p),
             prompt_cache=prompt_cache,
         ):
-            print(response.text, flush=True, end="")
-        print()
+            if world.rank() == 0:
+                print(response.text, flush=True, end="")
+        if world.rank() == 0:
+            print()
 
 
 if __name__ == "__main__":
     main()
+
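
Note on the helper above: share_message broadcasts the tokenized prompt from rank 0 to every other rank in two collective steps, first an all_sum of the prompt length, then an all_sum of the token ids against zero buffers on the non-root ranks. Below is a minimal standalone sketch of that same pattern, not the patch's code; the file name broadcast_demo.py and the helper name broadcast_from_root are made up for illustration, and it relies only on mx.distributed.init, mx.distributed.all_sum, rank() and size(), which the patch itself already uses.

    # broadcast_demo.py (hypothetical file name)
    import mlx.core as mx


    def broadcast_from_root(world, items):
        # Step 1: agree on the payload length; non-root ranks contribute 0.
        size = mx.array([len(items) if world.rank() == 0 else 0])
        size = mx.distributed.all_sum(size, stream=mx.cpu).item()
        if size == 0:
            return []
        # Step 2: sum the payload against all-zero buffers on the other ranks,
        # which leaves every rank holding rank 0's values.
        payload = mx.array(items) if world.rank() == 0 else mx.array([0] * size)
        return mx.distributed.all_sum(payload, stream=mx.cpu).tolist()


    if __name__ == "__main__":
        world = mx.distributed.init()
        tokens = [101, 2023, 102] if world.rank() == 0 else []
        print(f"rank {world.rank()} got {broadcast_from_root(world, tokens)}", flush=True)

Run as a single process it simply echoes rank 0's list; with a distributed backend configured it would typically be launched with something like mpirun across the participating hosts.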