5 Commits

Author                 SHA1        Message                                        Date
Awni Hannun            65b792d7c0  fix lazy load                                  2025-02-06 07:28:59 -08:00
Angelos Katharopoulos  617f9289b9  Make the chat distributed                      2025-02-06 07:28:59 -08:00
Angelos Katharopoulos  026362e0f8  Remove async eval and add sequential load      2025-02-06 07:28:58 -08:00
Angelos Katharopoulos  a0ce0594f6  Temporarily remove async_eval                  2025-02-06 07:28:03 -08:00
Angelos Katharopoulos  d77840207c  Start distributed inference for llama models   2025-02-06 07:28:03 -08:00
4 changed files with 99 additions and 13 deletions

View File

@@ -16,6 +16,25 @@ DEFAULT_MAX_TOKENS = 256
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
 
 
+def share_message(world, prompt):
+    if world.size() == 1:
+        return prompt
+
+    if world.rank() == 0:
+        size = mx.array([len(prompt)])
+    else:
+        size = mx.array([0])
+    size = mx.distributed.all_sum(size, stream=mx.cpu).item()
+    if size == 0:
+        return []
+
+    if world.rank() == 0:
+        prompt = mx.array(prompt)
+    else:
+        prompt = mx.array([0] * size)
+    return mx.distributed.all_sum(prompt, stream=mx.cpu).tolist()
+
+
 def setup_arg_parser():
     """Set up and return the argument parser."""
     parser = argparse.ArgumentParser(description="Chat with an LLM")
@@ -54,6 +73,7 @@ def setup_arg_parser():
 
 
 def main():
+    world = mx.distributed.init()
     parser = setup_arg_parser()
     args = parser.parse_args()
@@ -63,16 +83,30 @@ def main():
         args.model,
         adapter_path=args.adapter_path,
         tokenizer_config={"trust_remote_code": True},
+        sequential_load=mx.distributed.init().size() > 1,
     )
 
-    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
+    print(f"Node {world.rank()} of {world.size()}", flush=True)
+    print(
+        f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.",
+        flush=True,
+    )
+    world.barrier()
+
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
-        query = input(">> ")
-        if query == "q":
-            break
-        messages = [{"role": "user", "content": query}]
-        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
+        if world.rank() == 0:
+            query = input(">> ")
+            if query == "q":
+                prompt = []
+            else:
+                messages = [{"role": "user", "content": query}]
+                prompt = tokenizer.apply_chat_template(
+                    messages, add_generation_prompt=True
+                )
+        prompt = share_message(world, prompt)
+        if len(prompt) == 0:
+            break
         for response in stream_generate(
             model,
             tokenizer,
@@ -81,7 +115,9 @@ def main():
             sampler=make_sampler(args.temp, args.top_p),
             prompt_cache=prompt_cache,
         ):
-            print(response.text, flush=True, end="")
-        print()
+            if world.rank() == 0:
+                print(response.text, flush=True, end="")
+        if world.rank() == 0:
+            print()
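The share_message helper works around the lack of a broadcast call in this script by using mx.distributed.all_sum twice: rank 0 contributes the real values and every other rank contributes zeros, so the element-wise sum reproduces rank 0's data on all ranks (first the prompt length, then the zero-padded tokens; an empty prompt tells the workers to exit). A minimal standalone sketch of the same pattern, with a hypothetical helper name and toy token values:

```python
import mlx.core as mx


def broadcast_list(group, data):
    """Share a list of ints held by rank 0 with every rank in `group`."""
    if group.size() == 1:
        return data

    # Step 1: agree on the length; only rank 0 contributes a non-zero value.
    n = mx.array([len(data) if group.rank() == 0 else 0])
    n = mx.distributed.all_sum(n, stream=mx.cpu).item()
    if n == 0:
        return []

    # Step 2: sum the payload; non-root ranks add zeros, so the result
    # equals rank 0's list on every rank.
    payload = mx.array(data) if group.rank() == 0 else mx.array([0] * n)
    return mx.distributed.all_sum(payload, stream=mx.cpu).tolist()


if __name__ == "__main__":
    world = mx.distributed.init()
    tokens = [101, 7592, 2088] if world.rank() == 0 else []
    print(world.rank(), broadcast_list(world, tokens))
```

Run under a distributed launcher, every rank prints the same token list, which is exactly what the chat loop needs before calling stream_generate on all nodes.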

View File

@@ -191,6 +191,7 @@ def main():
         model_path,
         adapter_path=args.adapter_path,
         tokenizer_config=tokenizer_config,
+        sequential_load=mx.distributed.init().size() > 1,
     )
     for eos_token in args.extra_eos_token:
         tokenizer.add_eos_token(eos_token)
@@ -234,13 +235,17 @@ def main():
     else:
         draft_model = None
     sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
+
+    world = mx.distributed.init()
+    print(f"Node {world.rank()} of {world.size()}", flush=True)
+    world.barrier()
     response = generate(
         model,
         tokenizer,
         prompt,
         max_tokens=args.max_tokens,
-        verbose=args.verbose,
         sampler=sampler,
+        verbose=args.verbose and world.rank() == 0,
         max_kv_size=args.max_kv_size,
         prompt_cache=prompt_cache if using_cache else None,
         kv_bits=args.kv_bits,
@@ -249,8 +254,10 @@ def main():
         draft_model=draft_model,
         num_draft_tokens=args.num_draft_tokens,
     )
-    if not args.verbose:
+
+    if not args.verbose and mx.distributed.init().rank() == 0:
        print(response)
+    mx.synchronize()
 
 
 if __name__ == "__main__":
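The generate script adopts the same single-writer convention as the chat script: every rank runs the model, but only rank 0 writes to stdout (verbose is masked with world.rank() == 0 and the final print is rank-gated). A condensed sketch of that control flow, with the actual generate(...) call replaced by a placeholder string:

```python
import mlx.core as mx

world = mx.distributed.init()
print(f"Node {world.rank()} of {world.size()}", flush=True)
world.barrier()  # do not start generating until every rank is ready

# Every rank participates in the computation
# (stand-in for generate(model, tokenizer, prompt, ...)).
response = f"hello from a {world.size()}-way run"

# Single-writer rule: only rank 0 talks to the terminal.
if world.rank() == 0:
    print(response)

mx.synchronize()  # flush any pending work before the process exits
```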

View File

@@ -200,6 +200,36 @@ class Model(nn.Module):
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
         }
 
+    def shard(self, group: Optional[mx.distributed.Group] = None):
+        group = group or mx.distributed.init()
+
+        def all_to_sharded(l):
+            if isinstance(l, nn.QuantizedLinear):
+                return nn.QuantizedAllToShardedLinear.from_quantized_linear(l, group)
+            else:
+                return nn.AllToShardedLinear.from_linear(l, group)
+
+        def sharded_to_all(l):
+            if isinstance(l, nn.QuantizedLinear):
+                return nn.QuantizedShardedToAllLinear.from_quantized_linear(l, group)
+            else:
+                return nn.ShardedToAllLinear.from_linear(l, group)
+
+        N = group.size()
+        for layer in self.model.layers:
+            # Shard the self attention
+            layer.self_attn.q_proj = all_to_sharded(layer.self_attn.q_proj)
+            layer.self_attn.k_proj = all_to_sharded(layer.self_attn.k_proj)
+            layer.self_attn.v_proj = all_to_sharded(layer.self_attn.v_proj)
+            layer.self_attn.o_proj = sharded_to_all(layer.self_attn.o_proj)
+            layer.self_attn.n_heads //= N
+            layer.self_attn.n_kv_heads //= N
+
+            # Shard the MLP
+            layer.mlp.gate_proj = all_to_sharded(layer.mlp.gate_proj)
+            layer.mlp.down_proj = sharded_to_all(layer.mlp.down_proj)
+            layer.mlp.up_proj = all_to_sharded(layer.mlp.up_proj)
+
     @property
     def layers(self):
         return self.model.layers
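The shard() method is standard tensor parallelism: projections that fan out (q/k/v, gate, up) become column-parallel AllToSharded layers, projections that fan back in (o_proj, down) become row-parallel ShardedToAll layers that all-reduce their partial outputs, and the head counts are divided by the group size so attention runs only on local heads. The sketch below applies the same two conversions to a toy MLP; the sizes are arbitrary and assumed to divide evenly by the number of ranks:

```python
import mlx.core as mx
import mlx.nn as nn

group = mx.distributed.init()
mx.random.seed(0)  # same seed so every rank builds identical full-size layers

# A plain two-layer MLP: 512 -> 2048 -> 512.
up = nn.Linear(512, 2048)
down = nn.Linear(2048, 512)

# Column-parallel: each rank keeps 2048 // N output features of `up`.
up = nn.AllToShardedLinear.from_linear(up, group)
# Row-parallel: each rank keeps 2048 // N input features of `down` and
# all-reduces the partial outputs so every rank ends with the full result.
down = nn.ShardedToAllLinear.from_linear(down, group)

x = mx.random.normal((1, 512))
y = down(nn.silu(up(x)))
print(group.rank(), y.shape)  # (1, 512) on every rank
```

Only the row-parallel layers need a communication step, which is why o_proj and down_proj get the ShardedToAll conversion while everything else stays communication-free.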

View File

@@ -306,12 +306,12 @@ def generate_step(
 
     y, logprobs = _step(y)
 
-    mx.async_eval(y, logprobs)
+    mx.eval(y, logprobs)
     n = 0
     while True:
         if n != max_tokens:
             next_y, next_logprobs = _step(y)
-            mx.async_eval(next_y, next_logprobs)
+            mx.eval(next_y, next_logprobs)
         if n == 0:
             mx.eval(y)
             prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
@@ -628,6 +628,7 @@ def load_model(
     model_path: Path,
     lazy: bool = False,
     strict: bool = True,
+    sequential_load: bool = False,
     model_config: dict = {},
     get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
 ) -> nn.Module:
@@ -699,7 +700,16 @@ def load_model(
     model.load_weights(list(weights.items()), strict=strict)
 
+    if mx.distributed.init().size() > 1:
+        if not hasattr(model, "shard"):
+            raise RuntimeError("Model doesn't support distributed inference.")
+        model.shard()
+
     if not lazy:
+        weights.clear()
+        if sequential_load:
+            for layer in model.layers:
+                mx.eval(layer.parameters())
         mx.eval(model.parameters())
 
     model.eval()
@@ -712,6 +722,7 @@ def load(
     model_config={},
     adapter_path: Optional[str] = None,
     lazy: bool = False,
+    sequential_load: bool = False,
 ) -> Tuple[nn.Module, TokenizerWrapper]:
     """
     Load the model and tokenizer from a given path or a huggingface repository.
@@ -727,6 +738,8 @@ def load(
         lazy (bool): If ``False`` eval the model parameters to make sure they are
             loaded in memory before returning, otherwise they will be loaded
            when needed. Default: ``False``
+        sequential_load (bool): If True then load each layer sequentially to
+            ensure that we are not wasting memory.
 
     Returns:
         Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
@@ -736,7 +749,7 @@ def load(
     """
     model_path = get_model_path(path_or_hf_repo)
-    model, config = load_model(model_path, lazy)
+    model, config = load_model(model_path, lazy=lazy, sequential_load=sequential_load)
     if adapter_path is not None:
         model = load_adapters(model, adapter_path)
         model.eval()
@@ -750,7 +763,7 @@ def load(
 def fetch_from_hub(
     model_path: Path, lazy: bool = False
 ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
-    model, config = load_model(model_path, lazy)
+    model, config = load_model(model_path, lazy=lazy)
     tokenizer = load_tokenizer(
         model_path, eos_token_ids=config.get("eos_token_id", None)
     )
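The loader changes keep the call site simple: sharding happens inside load_model whenever the world size is greater than one, and callers only opt into layer-by-layer weight materialization through the new flag. A sketch of the intended usage, assuming the usual mlx_lm package layout (the import path is not part of this diff):

```python
import mlx.core as mx
from mlx_lm.utils import load  # assumed module path; `load` is defined above

world = mx.distributed.init()

# On multi-node runs, materialize the lazy weights one layer at a time so a
# node never holds more of the checkpoint than it needs before shard() drops
# the portions belonging to other ranks.
model, tokenizer = load(
    "mlx-community/Llama-3.2-3B-Instruct-4bit",
    sequential_load=world.size() > 1,
)
```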