Make the chat distributed

Angelos Katharopoulos, 2024-11-05 13:09:34 -08:00 (committed by Awni Hannun)
parent 026362e0f8
commit 617f9289b9


@@ -15,6 +15,24 @@ DEFAULT_SEED = 0
 DEFAULT_MAX_TOKENS = 256
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
 
+
+def share_message(world, prompt):
+    if world.size() == 1:
+        return prompt
+
+    if world.rank() == 0:
+        size = mx.array([len(prompt)])
+    else:
+        size = mx.array([0])
+    size = mx.distributed.all_sum(size, stream=mx.cpu).item()
+    if size == 0:
+        return []
+
+    if world.rank() == 0:
+        prompt = mx.array(prompt)
+    else:
+        prompt = mx.array([0] * size)
+    return mx.distributed.all_sum(prompt, stream=mx.cpu).tolist()
+
 
 def setup_arg_parser():
     """Set up and return the argument parser."""
@@ -54,6 +72,7 @@ def setup_arg_parser():
 def main():
+    world = mx.distributed.init()
     parser = setup_arg_parser()
     args = parser.parse_args()
@@ -63,16 +82,27 @@ def main():
         args.model,
         adapter_path=args.adapter_path,
         tokenizer_config={"trust_remote_code": True},
+        sequential_load=mx.distributed.init().size() > 1,
     )
-    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
+    print(f"Node {world.rank()} of {world.size()}", flush=True)
+    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.", flush=True)
+    world.barrier()
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
-        query = input(">> ")
-        if query == "q":
+        if world.rank() == 0:
+            query = input(">> ")
+            if query == "q":
+                prompt = []
+            else:
+                messages = [{"role": "user", "content": query}]
+                prompt = tokenizer.apply_chat_template(
+                    messages, add_generation_prompt=True
+                )
+        else:
+            prompt = []
+        prompt = share_message(world, prompt)
+        if len(prompt) == 0:
             break
-        messages = [{"role": "user", "content": query}]
-        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
         for response in stream_generate(
             model,
             tokenizer,
@@ -81,9 +111,12 @@ def main():
             sampler=make_sampler(args.temp, args.top_p),
             prompt_cache=prompt_cache,
         ):
-            print(response.text, flush=True, end="")
-        print()
+            if world.rank() == 0:
+                print(response.text, flush=True, end="")
+        if world.rank() == 0:
+            print()
 
 
 if __name__ == "__main__":
     main()
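After this change every rank runs the same generation loop on the broadcast prompt, while only rank 0 reads stdin and prints output. A quick way to sanity-check the communication layer before starting a chat session is a tiny all_sum round trip; the file name and the mpirun invocation below are assumptions about a typical MLX MPI setup, not something this commit adds.

    # ring_check.py -- hypothetical smoke test; e.g. mpirun -np 2 -- python ring_check.py
    import mlx.core as mx

    world = mx.distributed.init()
    # Every rank contributes 1, so the reduced value should equal world.size().
    total = mx.distributed.all_sum(mx.array([1]), stream=mx.cpu).item()
    print(f"Node {world.rank()} of {world.size()}: all_sum = {total}")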