5 Commits

Author SHA1 Message Date
Awni Hannun
65b792d7c0 fix lazy load 2025-02-06 07:28:59 -08:00
Angelos Katharopoulos
617f9289b9 Make the chat distributed 2025-02-06 07:28:59 -08:00
Angelos Katharopoulos
026362e0f8 Remove async eval and add sequential load 2025-02-06 07:28:58 -08:00
Angelos Katharopoulos
a0ce0594f6 Temporarily remove async_eval 2025-02-06 07:28:03 -08:00
Angelos Katharopoulos
d77840207c Start distributed inference for llama models 2025-02-06 07:28:03 -08:00
4 changed files with 99 additions and 13 deletions

View File

@@ -16,6 +16,25 @@ DEFAULT_MAX_TOKENS = 256
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
def share_message(world, prompt):
if world.size() == 1:
return prompt
if world.rank() == 0:
size = mx.array([len(prompt)])
else:
size = mx.array([0])
size = mx.distributed.all_sum(size, stream=mx.cpu).item()
if size == 0:
return []
if world.rank() == 0:
prompt = mx.array(prompt)
else:
prompt = mx.array([0] * len(prompt))
return mx.distributed.all_sum(size, stream=mx.cpu).tolist()
def setup_arg_parser():
"""Set up and return the argument parser."""
parser = argparse.ArgumentParser(description="Chat with an LLM")
@@ -54,6 +73,7 @@ def setup_arg_parser():
def main():
world = mx.distributed.init()
parser = setup_arg_parser()
args = parser.parse_args()
@@ -63,16 +83,30 @@ def main():
args.model,
adapter_path=args.adapter_path,
tokenizer_config={"trust_remote_code": True},
sequential_load=mx.distributed.init().size() > 1,
)
print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
print(f"Node {world.rank()} of {world.size()}", flush=True)
print(
f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.",
flush=True,
)
world.barrier()
prompt_cache = make_prompt_cache(model, args.max_kv_size)
while True:
if world.rank() == 0:
query = input(">> ")
if query == "q":
break
prompt = []
else:
messages = [{"role": "user", "content": query}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
prompt = tokenizer.apply_chat_template(
messages, add_generation_prompt=True
)
prompt = share_message(world, prompt)
if len(prompt) == 0:
break
for response in stream_generate(
model,
tokenizer,
@@ -81,7 +115,9 @@ def main():
sampler=make_sampler(args.temp, args.top_p),
prompt_cache=prompt_cache,
):
print(response.text, flush=True, end="")
if world.rank() == 0:
print(response, flush=True, end="")
if world.rank() == 0:
print()

View File

@@ -191,6 +191,7 @@ def main():
model_path,
adapter_path=args.adapter_path,
tokenizer_config=tokenizer_config,
sequential_load=mx.distributed.init().size() > 1,
)
for eos_token in args.extra_eos_token:
tokenizer.add_eos_token(eos_token)
@@ -234,13 +235,17 @@ def main():
else:
draft_model = None
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
world = mx.distributed.init()
print(f"Node {world.rank()} of {world.size()}", flush=True)
world.barrier()
response = generate(
model,
tokenizer,
prompt,
max_tokens=args.max_tokens,
verbose=args.verbose,
sampler=sampler,
verbose=args.verbose and world.rank() == 0,
max_kv_size=args.max_kv_size,
prompt_cache=prompt_cache if using_cache else None,
kv_bits=args.kv_bits,
@@ -249,8 +254,10 @@ def main():
draft_model=draft_model,
num_draft_tokens=args.num_draft_tokens,
)
if not args.verbose:
if not args.verbose and mx.distributed.init().rank() == 0:
print(response)
mx.synchronize()
if __name__ == "__main__":

View File

@@ -200,6 +200,36 @@ class Model(nn.Module):
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}
def shard(self, group: Optional[mx.distributed.Group] = None):
group = group or mx.distributed.init()
def all_to_sharded(l):
if isinstance(l, nn.QuantizedLinear):
return nn.QuantizedAllToShardedLinear.from_quantized_linear(l, group)
else:
return nn.AllToShardedLinear.from_linear(l, group)
def sharded_to_all(l):
if isinstance(l, nn.QuantizedLinear):
return nn.QuantizedShardedToAllLinear.from_quantized_linear(l, group)
else:
return nn.ShardedToAllLinear.from_linear(l, group)
N = group.size()
for layer in self.model.layers:
# Shard the self attention
layer.self_attn.q_proj = all_to_sharded(layer.self_attn.q_proj)
layer.self_attn.k_proj = all_to_sharded(layer.self_attn.k_proj)
layer.self_attn.v_proj = all_to_sharded(layer.self_attn.v_proj)
layer.self_attn.o_proj = sharded_to_all(layer.self_attn.o_proj)
layer.self_attn.n_heads //= N
layer.self_attn.n_kv_heads //= N
# Shard the MLP
layer.mlp.gate_proj = all_to_sharded(layer.mlp.gate_proj)
layer.mlp.down_proj = sharded_to_all(layer.mlp.down_proj)
layer.mlp.up_proj = all_to_sharded(layer.mlp.up_proj)
@property
def layers(self):
return self.model.layers

View File

@@ -306,12 +306,12 @@ def generate_step(
y, logprobs = _step(y)
mx.async_eval(y, logprobs)
mx.eval(y, logprobs)
n = 0
while True:
if n != max_tokens:
next_y, next_logprobs = _step(y)
mx.async_eval(next_y, next_logprobs)
mx.eval(next_y, next_logprobs)
if n == 0:
mx.eval(y)
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
@@ -628,6 +628,7 @@ def load_model(
model_path: Path,
lazy: bool = False,
strict: bool = True,
sequential_load: bool = False,
model_config: dict = {},
get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
) -> nn.Module:
@@ -699,7 +700,16 @@ def load_model(
model.load_weights(list(weights.items()), strict=strict)
if mx.distributed.init().size() > 1:
if not hasattr(model, "shard"):
raise RuntimeError("Model doesn't support distributed inference.")
model.shard()
if not lazy:
weights.clear()
if sequential_load:
for layer in model.layers:
mx.eval(layer.parameters())
mx.eval(model.parameters())
model.eval()
@@ -712,6 +722,7 @@ def load(
model_config={},
adapter_path: Optional[str] = None,
lazy: bool = False,
sequential_load: bool = False,
) -> Tuple[nn.Module, TokenizerWrapper]:
"""
Load the model and tokenizer from a given path or a huggingface repository.
@@ -727,6 +738,8 @@ def load(
lazy (bool): If ``False`` eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False``
sequential_load (bool): If True then load each layer sequentially to
ensure that we are not wasting memory.
Returns:
Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
@@ -736,7 +749,7 @@ def load(
"""
model_path = get_model_path(path_or_hf_repo)
model, config = load_model(model_path, lazy)
model, config = load_model(model_path, lazy=lazy, sequential_load=sequential_load)
if adapter_path is not None:
model = load_adapters(model, adapter_path)
model.eval()
@@ -750,7 +763,7 @@ def load(
def fetch_from_hub(
model_path: Path, lazy: bool = False
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
model, config = load_model(model_path, lazy)
model, config = load_model(model_path, lazy=lazy)
tokenizer = load_tokenizer(
model_path, eos_token_ids=config.get("eos_token_id", None)
)