Merge branch 'ml-explore:main' into adding-dpo-training

This commit is contained in:
Gökdeniz Gülmez 2025-02-04 10:45:50 +01:00 committed by GitHub
commit 9b489a6c0c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 208 additions and 94 deletions

View File

@ -4,10 +4,11 @@
Run with:
```
/path/to/mpirun \
-np 2 \
mlx.launch \
--hostfile /path/to/hosts.txt \
python /path/to/pipeline_generate.py --prompt "hello world"
--backend mpi \
/path/to/pipeline_generate.py \
--prompt "hello world"
```
Make sure you can run MLX over MPI on two hosts. For more information see the
@ -17,62 +18,110 @@ https://ml-explore.github.io/mlx/build/html/usage/distributed.html).
"""
import argparse
import json
from pathlib import Path
import mlx.core as mx
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
from mlx_lm import load, stream_generate
parser = argparse.ArgumentParser(description="LLM pipelined inference example")
parser.add_argument(
"--model",
default="mlx-community/DeepSeek-R1-3bit",
help="HF repo or path to local model.",
)
parser.add_argument(
"--prompt",
"-p",
default="Write a quicksort in C++.",
help="Message to be processed by the model ('-' reads from stdin)",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=256,
help="Maximum number of tokens to generate",
)
args = parser.parse_args()
model, tokenizer = load(args.model, lazy=True)
messages = [{"role": "user", "content": args.prompt}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
group = mx.distributed.init()
rank = group.rank()
model.model.pipeline(group)
mx.eval(model.parameters())
# Synchronize processes before generation to avoid timeout if downloading
# model for the first time.
mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
from mlx_lm.utils import load_model, load_tokenizer
def rprint(*args, **kwargs):
if rank == 0:
print(*args, **kwargs)
def download(repo: str, allow_patterns: list[str]) -> Path:
return Path(
snapshot_download(
repo,
allow_patterns=allow_patterns,
)
)
for response in stream_generate(model, tokenizer, prompt, max_tokens=args.max_tokens):
rprint(response.text, end="", flush=True)
def shard_and_load(repo):
# Get model path with everything but weight safetensors
model_path = download(
args.model,
allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
)
rprint()
rprint("=" * 10)
rprint(
f"Prompt: {response.prompt_tokens} tokens, "
f"{response.prompt_tps:.3f} tokens-per-sec"
)
rprint(
f"Generation: {response.generation_tokens} tokens, "
f"{response.generation_tps:.3f} tokens-per-sec"
)
rprint(f"Peak memory: {response.peak_memory:.3f} GB")
# Lazy load and shard model to figure out
# which weights we need
model, _ = load_model(model_path, lazy=True, strict=False)
group = mx.distributed.init(backend="mpi")
rank = group.rank()
model.model.pipeline(group)
# Figure out which files we need for the local shard
with open(model_path / "model.safetensors.index.json", "r") as fid:
weight_index = json.load(fid)["weight_map"]
local_files = set()
for k, _ in tree_flatten(model.parameters()):
local_files.add(weight_index[k])
# Download weights for local shard
download(args.model, allow_patterns=local_files)
# Load and shard the model, and load the weights
tokenizer = load_tokenizer(model_path)
model, _ = load_model(model_path, lazy=True, strict=False)
model.model.pipeline(group)
mx.eval(model.parameters())
# Synchronize processes before generation to avoid timeout if downloading
# model for the first time.
mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
return model, tokenizer
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="LLM pipelined inference example")
parser.add_argument(
"--model",
default="mlx-community/DeepSeek-R1-3bit",
help="HF repo or path to local model.",
)
parser.add_argument(
"--prompt",
"-p",
default="Write a quicksort in C++.",
help="Message to be processed by the model ('-' reads from stdin)",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=256,
help="Maximum number of tokens to generate",
)
args = parser.parse_args()
group = mx.distributed.init(backend="mpi")
rank = group.rank()
def rprint(*args, **kwargs):
if rank == 0:
print(*args, **kwargs)
model, tokenizer = shard_and_load(args.model)
messages = [{"role": "user", "content": args.prompt}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
for response in stream_generate(
model, tokenizer, prompt, max_tokens=args.max_tokens
):
rprint(response.text, end="", flush=True)
rprint()
rprint("=" * 10)
rprint(
f"Prompt: {response.prompt_tokens} tokens, "
f"{response.prompt_tps:.3f} tokens-per-sec"
)
rprint(
f"Generation: {response.generation_tokens} tokens, "
f"{response.generation_tps:.3f} tokens-per-sec"
)
rprint(f"Peak memory: {response.peak_memory:.3f} GB")

View File

@ -364,8 +364,30 @@ class DeepseekV2Model(nn.Module):
DeepseekV2DecoderLayer(config, idx)
for idx in range(config.num_hidden_layers)
]
self.start_idx = 0
self.end_idx = len(self.layers)
self.num_layers = self.end_idx
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pipeline_rank = 0
self.pipeline_size = 1
def pipeline(self, group):
# Split layers in reverse so rank=0 gets the last layers and
# rank=pipeline_size-1 gets the first
self.pipeline_rank = group.rank()
self.pipeline_size = group.size()
layers_per_rank = (
len(self.layers) + self.pipeline_size - 1
) // self.pipeline_size
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
self.end_idx = self.start_idx + layers_per_rank
self.num_layers = layers_per_rank
self.layers = self.layers[: self.end_idx]
self.layers[: self.start_idx] = [None] * self.start_idx
self.num_layers = len(self.layers) - self.start_idx
def __call__(
self,
x: mx.array,
@ -374,14 +396,31 @@ class DeepseekV2Model(nn.Module):
) -> mx.array:
h = self.embed_tokens(x)
pipeline_rank = self.pipeline_rank
pipeline_size = self.pipeline_size
# Hack to avoid time-outs during prompt-processing
dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
cache = [None] * self.num_layers
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
# Receive from the previous process in the pipeline
if pipeline_rank < pipeline_size - 1:
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
for i in range(self.num_layers):
h = self.layers[self.start_idx + i](h, mask, cache[i])
# Send to the next process in the pipeline
if pipeline_rank != 0:
h = mx.distributed.send(
h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
)
# Broadcast h while keeping it in the graph
h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
return self.norm(h)
@ -418,4 +457,4 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
return self.model.layers[self.model.start_idx : self.model.end_idx]

View File

@ -381,6 +381,10 @@ class DeepseekV3Model(nn.Module):
DeepseekV3DecoderLayer(config, idx)
for idx in range(config.num_hidden_layers)
]
self.start_idx = 0
self.end_idx = len(self.layers)
self.num_layers = self.end_idx
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pipeline_rank = 0
self.pipeline_size = 1
@ -393,8 +397,11 @@ class DeepseekV3Model(nn.Module):
layers_per_rank = (
len(self.layers) + self.pipeline_size - 1
) // self.pipeline_size
start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
self.layers = self.layers[start : start + layers_per_rank]
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
self.end_idx = self.start_idx + layers_per_rank
self.layers = self.layers[: self.end_idx]
self.layers[: self.start_idx] = [None] * self.start_idx
self.num_layers = len(self.layers) - self.start_idx
def __call__(
self,
@ -412,15 +419,15 @@ class DeepseekV3Model(nn.Module):
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
cache = [None] * self.num_layers
# Receive from the previous process in the pipeline
if pipeline_rank < pipeline_size - 1:
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
for i in range(self.num_layers):
h = self.layers[self.start_idx + i](h, mask, cache[i])
# Send to the next process in the pipeline
if pipeline_rank != 0:
@ -468,4 +475,4 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
return self.model.layers[self.model.start_idx : self.model.end_idx]

View File

@ -1,3 +1,5 @@
# Copyright © 2025 Apple Inc.
from dataclasses import dataclass
from typing import Any, Optional, Tuple

View File

@ -1,4 +1,4 @@
# Copyright © 2024 Apple Inc.
# Copyright © 2024-2025 Apple Inc.
import math
from dataclasses import dataclass
@ -123,17 +123,16 @@ class MambaBlock(nn.Module):
self.intermediate_size, self.hidden_size, bias=args.use_bias
)
def ssm_step(self, x, state=None):
A = -mx.exp(self.A_log)
def ssm_step(self, x, A, state=None):
D = self.D
deltaBC = self.x_proj(x)
delta, B, C = mx.split(
deltaBC,
indices_or_sections=[
self.time_step_rank,
self.time_step_rank + self.ssm_state_size,
],
axis=-1,
delta, B, C = map(
self.mixer_norm if self.use_bcdt_rms else lambda x: x,
mx.split(
deltaBC,
[self.time_step_rank, self.time_step_rank + self.ssm_state_size],
axis=-1,
),
)
if self.use_bcdt_rms:
delta, B, C = map(self.mixer_norm, (delta, B, C))
@ -145,25 +144,40 @@ class MambaBlock(nn.Module):
y = y + D * x
return y, new_state
def __call__(self, x, cache):
def _process_sequence(self, x, conv_cache, state_cache):
B, T, D = x.shape
if cache is None:
cache = [None, None]
xz = self.in_proj(x)
x, z = xz.split(indices_or_sections=2, axis=-1)
conv_out, new_conv_cache = self.conv1d(x, conv_cache)
x = nn.silu(conv_out)
A = -mx.exp(self.A_log)
outputs = []
current_state = state_cache
y = []
for t in range(T):
xt = x[:, t, :]
xz = self.in_proj(xt)
x_t, z_t = xz.split(indices_or_sections=2, axis=1)
conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
x_t = conv_out.squeeze(1)
x_t = nn.silu(x_t)
y_t, cache[1] = self.ssm_step(x_t, cache[1])
z_t = nn.silu(z_t)
output_t = y_t * z_t
output_t = self.out_proj(output_t)
outputs.append(output_t)
output = mx.stack(outputs, axis=1)
y_t, current_state = self.ssm_step(x[:, t], A, current_state)
y.append(y_t)
y = mx.stack(y, axis=1)
z = self.out_proj(nn.silu(z) * y)
return z, (new_conv_cache, current_state)
def __call__(self, x, cache):
if cache is None:
conv_cache, state_cache = None, None
else:
conv_cache, state_cache = cache[0], cache[1]
output, (new_conv_cache, new_state_cache) = self._process_sequence(
x, conv_cache, state_cache
)
if isinstance(cache, MambaCache):
cache[0] = new_conv_cache
cache[1] = new_state_cache
return output

View File

@ -1,4 +1,4 @@
# Copyright © 2023-2024 Apple Inc.
# Copyright © 2023-2025 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

View File

@ -140,8 +140,8 @@ def evaluate(
loss: callable = default_loss,
iterate_batches: callable = iterate_batches,
):
all_losses = 0
ntokens = 0
all_losses = mx.array(0.0)
ntokens = mx.array(0)
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)

View File

@ -627,6 +627,7 @@ def load_config(model_path: Path) -> dict:
def load_model(
model_path: Path,
lazy: bool = False,
strict: bool = True,
model_config: dict = {},
get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
) -> nn.Module:
@ -638,6 +639,8 @@ def load_model(
lazy (bool): If False eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False``
strict (bool): Whether or not to raise an exception if weights don't
match. Default: ``True``
model_config (dict, optional): Optional configuration parameters for the
model. Defaults to an empty dictionary.
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
@ -660,7 +663,7 @@ def load_model(
# Try weight for back-compat
weight_files = glob.glob(str(model_path / "weight*.safetensors"))
if not weight_files:
if not weight_files and strict:
logging.error(f"No safetensors found in {model_path}")
raise FileNotFoundError(f"No safetensors found in {model_path}")
@ -694,7 +697,7 @@ def load_model(
class_predicate=class_predicate,
)
model.load_weights(list(weights.items()))
model.load_weights(list(weights.items()), strict=strict)
if not lazy:
mx.eval(model.parameters())