Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-12-16 02:08:55 +08:00)

Compare commits: 43ff302638...distribute (5 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 65b792d7c0 | |
| | 617f9289b9 | |
| | 026362e0f8 | |
| | a0ce0594f6 | |
| | d77840207c | |
@@ -16,6 +16,25 @@ DEFAULT_MAX_TOKENS = 256
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
 
 
+def share_message(world, prompt):
+    if world.size() == 1:
+        return prompt
+
+    if world.rank() == 0:
+        size = mx.array([len(prompt)])
+    else:
+        size = mx.array([0])
+    size = mx.distributed.all_sum(size, stream=mx.cpu).item()
+    if size == 0:
+        return []
+
+    if world.rank() == 0:
+        prompt = mx.array(prompt)
+    else:
+        prompt = mx.array([0] * size)
+    return mx.distributed.all_sum(prompt, stream=mx.cpu).tolist()
+
+
 def setup_arg_parser():
     """Set up and return the argument parser."""
     parser = argparse.ArgumentParser(description="Chat with an LLM")
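The helper above lets rank 0, the only rank that reads stdin, share the tokenized prompt with every other rank: the prompt length is broadcast first, then the token ids, both via `mx.distributed.all_sum` with the non-root ranks contributing zeros. Below is a minimal, standalone sketch of the same pattern; the helper name `broadcast_ints_from_rank0` is illustrative and not part of mlx-lm.

```python
import mlx.core as mx


def broadcast_ints_from_rank0(world, values):
    """Broadcast a Python list of ints from rank 0 to every rank."""
    if world.size() == 1:
        return values

    # Agree on the length first: non-root ranks contribute 0, so the sum is
    # exactly the length known to rank 0.
    size = mx.array([len(values) if world.rank() == 0 else 0])
    size = mx.distributed.all_sum(size, stream=mx.cpu).item()
    if size == 0:
        return []

    # Then sum the payload against zero arrays on the other ranks.
    payload = mx.array(values) if world.rank() == 0 else mx.array([0] * size)
    return mx.distributed.all_sum(payload, stream=mx.cpu).tolist()


if __name__ == "__main__":
    world = mx.distributed.init()
    data = [1, 2, 3] if world.rank() == 0 else []
    print(world.rank(), broadcast_ints_from_rank0(world, data), flush=True)
```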
@@ -54,6 +73,7 @@ def setup_arg_parser():
 
 
 def main():
+    world = mx.distributed.init()
     parser = setup_arg_parser()
     args = parser.parse_args()
 
@@ -63,16 +83,30 @@ def main():
         args.model,
         adapter_path=args.adapter_path,
         tokenizer_config={"trust_remote_code": True},
+        sequential_load=mx.distributed.init().size() > 1,
     )
 
-    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
+    print(f"Node {world.rank()} of {world.size()}", flush=True)
+    print(
+        f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.",
+        flush=True,
+    )
+    world.barrier()
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
-        query = input(">> ")
-        if query == "q":
+        if world.rank() == 0:
+            query = input(">> ")
+            if query == "q":
+                prompt = []
+            else:
+                messages = [{"role": "user", "content": query}]
+                prompt = tokenizer.apply_chat_template(
+                    messages, add_generation_prompt=True
+                )
+
+        prompt = share_message(world, prompt if world.rank() == 0 else [])
+        if len(prompt) == 0:
             break
-        messages = [{"role": "user", "content": query}]
-        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
     for response in stream_generate(
         model,
         tokenizer,
@@ -81,8 +115,10 @@ def main():
         sampler=make_sampler(args.temp, args.top_p),
         prompt_cache=prompt_cache,
     ):
-        print(response.text, flush=True, end="")
-    print()
+        if world.rank() == 0:
+            print(response.text, flush=True, end="")
+    if world.rank() == 0:
+        print()
 
 
 if __name__ == "__main__":
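With these changes every rank runs the same chat loop, but only rank 0 talks to the terminal; the other ranks receive the prompt through `share_message` and still participate in `stream_generate` so the sharded model can produce each token. Before launching the chat across machines (for example with MLX's `mlx.launch` helper or `mpirun`), a quick connectivity check such as the hypothetical script below, which exercises the same `init`/`all_sum`/`barrier` calls, can confirm that all nodes see each other.

```python
import mlx.core as mx

# Hypothetical smoke test, not part of this diff: verify that all ranks are
# connected before starting a distributed chat or generate run.
world = mx.distributed.init()
peers = mx.distributed.all_sum(mx.array([1]), stream=mx.cpu).item()
print(f"Node {world.rank()} of {world.size()} reached {peers} rank(s)", flush=True)
world.barrier()
```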
@@ -191,6 +191,7 @@ def main():
         model_path,
         adapter_path=args.adapter_path,
         tokenizer_config=tokenizer_config,
+        sequential_load=mx.distributed.init().size() > 1,
     )
     for eos_token in args.extra_eos_token:
         tokenizer.add_eos_token(eos_token)
@@ -234,13 +235,17 @@ def main():
     else:
         draft_model = None
     sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
+
+    world = mx.distributed.init()
+    print(f"Node {world.rank()} of {world.size()}", flush=True)
+    world.barrier()
     response = generate(
         model,
         tokenizer,
         prompt,
         max_tokens=args.max_tokens,
-        verbose=args.verbose,
         sampler=sampler,
+        verbose=args.verbose and world.rank() == 0,
         max_kv_size=args.max_kv_size,
         prompt_cache=prompt_cache if using_cache else None,
         kv_bits=args.kv_bits,
@@ -249,8 +254,10 @@ def main():
         draft_model=draft_model,
         num_draft_tokens=args.num_draft_tokens,
     )
-    if not args.verbose:
+
+    if not args.verbose and mx.distributed.init().rank() == 0:
         print(response)
+    mx.synchronize()
 
 
 if __name__ == "__main__":
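In the generate script the pattern is the same: every rank calls `generate`, but verbose streaming and the final `print(response)` are gated on rank 0, and `mx.synchronize()` keeps a rank from exiting while work is still queued. A minimal standalone sketch of that shutdown pattern (illustrative, not part of the diff):

```python
import mlx.core as mx

world = mx.distributed.init()
# Some distributed work; all_sum stands in for the sharded forward pass.
result = mx.distributed.all_sum(mx.array([world.rank()]), stream=mx.cpu)

if world.rank() == 0:
    print(result.item(), flush=True)  # only one rank writes to stdout
mx.synchronize()  # drain outstanding work before the process exits
```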
@@ -200,6 +200,36 @@ class Model(nn.Module):
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
         }
 
+    def shard(self, group: Optional[mx.distributed.Group] = None):
+        group = group or mx.distributed.init()
+
+        def all_to_sharded(l):
+            if isinstance(l, nn.QuantizedLinear):
+                return nn.QuantizedAllToShardedLinear.from_quantized_linear(l, group)
+            else:
+                return nn.AllToShardedLinear.from_linear(l, group)
+
+        def sharded_to_all(l):
+            if isinstance(l, nn.QuantizedLinear):
+                return nn.QuantizedShardedToAllLinear.from_quantized_linear(l, group)
+            else:
+                return nn.ShardedToAllLinear.from_linear(l, group)
+
+        N = group.size()
+        for layer in self.model.layers:
+            # Shard the self attention
+            layer.self_attn.q_proj = all_to_sharded(layer.self_attn.q_proj)
+            layer.self_attn.k_proj = all_to_sharded(layer.self_attn.k_proj)
+            layer.self_attn.v_proj = all_to_sharded(layer.self_attn.v_proj)
+            layer.self_attn.o_proj = sharded_to_all(layer.self_attn.o_proj)
+            layer.self_attn.n_heads //= N
+            layer.self_attn.n_kv_heads //= N
+
+            # Shard the MLP
+            layer.mlp.gate_proj = all_to_sharded(layer.mlp.gate_proj)
+            layer.mlp.down_proj = sharded_to_all(layer.mlp.down_proj)
+            layer.mlp.up_proj = all_to_sharded(layer.mlp.up_proj)
+
     @property
     def layers(self):
         return self.model.layers
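The `shard` method implements simple tensor parallelism: projections whose outputs can be split across ranks (`q_proj`, `k_proj`, `v_proj`, `gate_proj`, `up_proj`) become all-to-sharded layers, while the projections that consume those sharded activations (`o_proj`, `down_proj`) become sharded-to-all layers that reduce the partial results back to every rank, and the head counts are divided by the group size. The sketch below applies the same idea to a toy MLP; `TinyMLP` is illustrative only and simply reuses the `mlx.nn` conversion helpers seen above.

```python
import mlx.core as mx
import mlx.nn as nn


class TinyMLP(nn.Module):
    """Toy two-layer MLP sharded the same way as the model's MLP blocks."""

    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.up_proj = nn.Linear(dim, hidden)
        self.down_proj = nn.Linear(hidden, dim)

    def __call__(self, x):
        return self.down_proj(nn.relu(self.up_proj(x)))

    def shard(self, group=None):
        group = group or mx.distributed.init()
        # Split the hidden dimension across ranks, then reduce on the way out.
        self.up_proj = nn.AllToShardedLinear.from_linear(self.up_proj, group)
        self.down_proj = nn.ShardedToAllLinear.from_linear(self.down_proj, group)


if __name__ == "__main__":
    mlp = TinyMLP(16, 64)
    if mx.distributed.init().size() > 1:
        mlp.shard()
    print(mlp(mx.zeros((1, 16))).shape)
```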
@@ -306,12 +306,12 @@ def generate_step(
 
     y, logprobs = _step(y)
 
-    mx.async_eval(y, logprobs)
+    mx.eval(y, logprobs)
     n = 0
     while True:
         if n != max_tokens:
             next_y, next_logprobs = _step(y)
-            mx.async_eval(next_y, next_logprobs)
+            mx.eval(next_y, next_logprobs)
         if n == 0:
             mx.eval(y)
             prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
@@ -628,6 +628,7 @@ def load_model(
     model_path: Path,
     lazy: bool = False,
     strict: bool = True,
+    sequential_load: bool = False,
     model_config: dict = {},
     get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
 ) -> nn.Module:
@@ -699,7 +700,16 @@ def load_model(
 
     model.load_weights(list(weights.items()), strict=strict)
 
+    if mx.distributed.init().size() > 1:
+        if not hasattr(model, "shard"):
+            raise RuntimeError("Model doesn't support distributed inference.")
+        model.shard()
+
     if not lazy:
+        weights.clear()
+        if sequential_load:
+            for layer in model.layers:
+                mx.eval(layer.parameters())
         mx.eval(model.parameters())
 
     model.eval()
@@ -712,6 +722,7 @@ def load(
     model_config={},
     adapter_path: Optional[str] = None,
     lazy: bool = False,
+    sequential_load: bool = False,
 ) -> Tuple[nn.Module, TokenizerWrapper]:
     """
     Load the model and tokenizer from a given path or a huggingface repository.
@@ -727,6 +738,8 @@ def load(
         lazy (bool): If ``False`` eval the model parameters to make sure they are
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
+        sequential_load (bool): If ``True``, load each layer sequentially to
+            avoid wasting memory. Default: ``False``
     Returns:
         Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
 
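Taken together, the new `sequential_load` flag threads from `load` down to `load_model`, which also shards the model automatically whenever `mx.distributed` reports more than one rank. A hypothetical caller-side usage, mirroring what the chat and generate scripts now do (assuming the `mlx_lm` package built from this branch):

```python
import mlx.core as mx
from mlx_lm import load, generate

world = mx.distributed.init()
model, tokenizer = load(
    "mlx-community/Llama-3.2-3B-Instruct-4bit",
    sequential_load=world.size() > 1,  # evaluate one layer at a time when distributed
)
text = generate(model, tokenizer, prompt="Hello", max_tokens=32)
if world.rank() == 0:
    print(text)
```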
@@ -736,7 +749,7 @@ def load(
     """
     model_path = get_model_path(path_or_hf_repo)
 
-    model, config = load_model(model_path, lazy)
+    model, config = load_model(model_path, lazy=lazy, sequential_load=sequential_load)
     if adapter_path is not None:
         model = load_adapters(model, adapter_path)
     model.eval()
@@ -750,7 +763,7 @@ def load(
 def fetch_from_hub(
     model_path: Path, lazy: bool = False
 ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
-    model, config = load_model(model_path, lazy)
+    model, config = load_model(model_path, lazy=lazy)
     tokenizer = load_tokenizer(
         model_path, eos_token_ids=config.get("eos_token_id", None)
     )