Merge branch 'ml-explore:main' into adding-dpo-training

Gökdeniz Gülmez 2025-02-04 10:45:50 +01:00 committed by GitHub
commit 9b489a6c0c
8 changed files with 208 additions and 94 deletions

View File

@@ -4,10 +4,11 @@
 Run with:
 
 ```
-/path/to/mpirun \
- -np 2 \
+mlx.launch \
  --hostfile /path/to/hosts.txt \
-python /path/to/pipeline_generate.py --prompt "hello world"
+ --backend mpi \
+ /path/to/pipeline_generate.py \
+ --prompt "hello world"
 ```
 
 Make sure you can run MLX over MPI on two hosts. For more information see the
@@ -17,10 +18,64 @@ https://ml-explore.github.io/mlx/build/html/usage/distributed.html).
 """
 
 import argparse
+import json
+from pathlib import Path
 
 import mlx.core as mx
+from huggingface_hub import snapshot_download
+from mlx.utils import tree_flatten
 from mlx_lm import load, stream_generate
+from mlx_lm.utils import load_model, load_tokenizer
+
+
+def download(repo: str, allow_patterns: list[str]) -> Path:
+    return Path(
+        snapshot_download(
+            repo,
+            allow_patterns=allow_patterns,
+        )
+    )
+
+
+def shard_and_load(repo):
+    # Get model path with everything but weight safetensors
+    model_path = download(
+        args.model,
+        allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
+    )
+
+    # Lazy load and shard model to figure out
+    # which weights we need
+    model, _ = load_model(model_path, lazy=True, strict=False)
+    group = mx.distributed.init(backend="mpi")
+    rank = group.rank()
+    model.model.pipeline(group)
+
+    # Figure out which files we need for the local shard
+    with open(model_path / "model.safetensors.index.json", "r") as fid:
+        weight_index = json.load(fid)["weight_map"]
+
+    local_files = set()
+    for k, _ in tree_flatten(model.parameters()):
+        local_files.add(weight_index[k])
+
+    # Download weights for local shard
+    download(args.model, allow_patterns=local_files)
+
+    # Load and shard the model, and load the weights
+    tokenizer = load_tokenizer(model_path)
+    model, _ = load_model(model_path, lazy=True, strict=False)
+    model.model.pipeline(group)
+    mx.eval(model.parameters())
+
+    # Synchronize processes before generation to avoid timeout if downloading
+    # model for the first time.
+    mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
+    return model, tokenizer
+
+
+if __name__ == "__main__":
 parser = argparse.ArgumentParser(description="LLM pipelined inference example")
 parser.add_argument(
     "--model",
@@ -42,27 +97,21 @@ parser.add_argument(
 )
 args = parser.parse_args()
 
-model, tokenizer = load(args.model, lazy=True)
-
-messages = [{"role": "user", "content": args.prompt}]
-prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
-
-group = mx.distributed.init()
+group = mx.distributed.init(backend="mpi")
 rank = group.rank()
-model.model.pipeline(group)
-mx.eval(model.parameters())
-
-# Synchronize processes before generation to avoid timeout if downloading
-# model for the first time.
-mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
 
 
 def rprint(*args, **kwargs):
     if rank == 0:
         print(*args, **kwargs)
 
 
-for response in stream_generate(model, tokenizer, prompt, max_tokens=args.max_tokens):
+model, tokenizer = shard_and_load(args.model)
+
+messages = [{"role": "user", "content": args.prompt}]
+prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
+
+for response in stream_generate(
+    model, tokenizer, prompt, max_tokens=args.max_tokens
+):
     rprint(response.text, end="", flush=True)
 
 rprint()
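The shard-aware download above relies on the `weight_map` inside `model.safetensors.index.json`, which maps every parameter name to the safetensors shard that stores it. A minimal standalone sketch of that lookup (the helper name `files_for_params` and the example path are illustrative, not part of the diff):

```python
import json
from pathlib import Path


def files_for_params(model_path: Path, param_names: list[str]) -> set[str]:
    # weight_map looks like {"model.layers.0.self_attn.q_proj.weight": "model-00001-of-000NN.safetensors", ...}
    with open(model_path / "model.safetensors.index.json", "r") as fid:
        weight_map = json.load(fid)["weight_map"]
    # Only the shard files that actually hold these parameters need downloading.
    return {weight_map[name] for name in param_names}
```

In the script, the parameter names come from `tree_flatten(model.parameters())` after `model.model.pipeline(group)` has dropped the layers this rank does not own, so each host only fetches the weight shards for its own slice of the model.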

View File

@@ -364,8 +364,30 @@ class DeepseekV2Model(nn.Module):
             DeepseekV2DecoderLayer(config, idx)
             for idx in range(config.num_hidden_layers)
         ]
+        self.start_idx = 0
+        self.end_idx = len(self.layers)
+        self.num_layers = self.end_idx
         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.pipeline_rank = 0
+        self.pipeline_size = 1
+
+    def pipeline(self, group):
+        # Split layers in reverse so rank=0 gets the last layers and
+        # rank=pipeline_size-1 gets the first
+        self.pipeline_rank = group.rank()
+        self.pipeline_size = group.size()
+        layers_per_rank = (
+            len(self.layers) + self.pipeline_size - 1
+        ) // self.pipeline_size
+        self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
+        self.end_idx = self.start_idx + layers_per_rank
+        self.num_layers = layers_per_rank
+        self.layers = self.layers[: self.end_idx]
+        self.layers[: self.start_idx] = [None] * self.start_idx
+        self.num_layers = len(self.layers) - self.start_idx
 
     def __call__(
         self,
         x: mx.array,
@@ -374,14 +396,31 @@ class DeepseekV2Model(nn.Module):
     ) -> mx.array:
         h = self.embed_tokens(x)
 
+        pipeline_rank = self.pipeline_rank
+        pipeline_size = self.pipeline_size
+        # Hack to avoid time-outs during prompt-processing
+        dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
         if mask is None:
             mask = create_attention_mask(h, cache)
 
         if cache is None:
-            cache = [None] * len(self.layers)
+            cache = [None] * self.num_layers
 
-        for layer, c in zip(self.layers, cache):
-            h = layer(h, mask, c)
+        # Receive from the previous process in the pipeline
+        if pipeline_rank < pipeline_size - 1:
+            h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
+
+        for i in range(self.num_layers):
+            h = self.layers[self.start_idx + i](h, mask, cache[i])
+
+        # Send to the next process in the pipeline
+        if pipeline_rank != 0:
+            h = mx.distributed.send(
+                h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
+            )
+
+        # Broadcast h while keeping it in the graph
+        h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
 
         return self.norm(h)
@@ -418,4 +457,4 @@ class Model(nn.Module):
 
     @property
     def layers(self):
-        return self.model.layers
+        return self.model.layers[self.model.start_idx : self.model.end_idx]
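The `pipeline` method splits the layer list in reverse, so rank 0 ends up with the last block of layers and the highest rank with the first. A small standalone sketch of the same arithmetic (the function `split_layers` is illustrative; it returns the half-open index range a rank would run):

```python
def split_layers(num_layers: int, rank: int, size: int) -> tuple[int, int]:
    # Ceil-divide so every rank gets at most layers_per_rank layers.
    layers_per_rank = (num_layers + size - 1) // size
    # Reverse assignment: rank 0 takes the last chunk, rank size-1 the first.
    start = (size - rank - 1) * layers_per_rank
    end = min(start + layers_per_rank, num_layers)
    return start, end


# e.g. 27 layers on 2 ranks: rank 0 runs layers [14, 27), rank 1 runs [0, 14)
assert split_layers(27, 0, 2) == (14, 27)
assert split_layers(27, 1, 2) == (0, 14)
```

This is also why `num_layers` is recomputed as `len(self.layers) - self.start_idx` after the list is truncated: when the layer count does not divide evenly, the chunk holding the last layers is shorter than `layers_per_rank`.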

View File

@@ -381,6 +381,10 @@ class DeepseekV3Model(nn.Module):
             DeepseekV3DecoderLayer(config, idx)
             for idx in range(config.num_hidden_layers)
         ]
+        self.start_idx = 0
+        self.end_idx = len(self.layers)
+        self.num_layers = self.end_idx
         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pipeline_rank = 0
         self.pipeline_size = 1
@@ -393,8 +397,11 @@ class DeepseekV3Model(nn.Module):
         layers_per_rank = (
             len(self.layers) + self.pipeline_size - 1
         ) // self.pipeline_size
-        start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
-        self.layers = self.layers[start : start + layers_per_rank]
+        self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
+        self.end_idx = self.start_idx + layers_per_rank
+        self.layers = self.layers[: self.end_idx]
+        self.layers[: self.start_idx] = [None] * self.start_idx
+        self.num_layers = len(self.layers) - self.start_idx
 
     def __call__(
         self,
@@ -412,15 +419,15 @@ class DeepseekV3Model(nn.Module):
             mask = create_attention_mask(h, cache)
 
         if cache is None:
-            cache = [None] * len(self.layers)
+            cache = [None] * self.num_layers
 
         # Receive from the previous process in the pipeline
         if pipeline_rank < pipeline_size - 1:
             h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
 
-        for layer, c in zip(self.layers, cache):
-            h = layer(h, mask, c)
+        for i in range(self.num_layers):
+            h = self.layers[self.start_idx + i](h, mask, cache[i])
 
         # Send to the next process in the pipeline
         if pipeline_rank != 0:
@@ -468,4 +475,4 @@ class Model(nn.Module):
 
     @property
     def layers(self):
-        return self.model.layers
+        return self.model.layers[self.model.start_idx : self.model.end_idx]
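Both DeepSeek models use the same receive, run local layers, send, then all_gather pattern inside `__call__`. A condensed sketch of that hand-off, assuming the distributed group is already initialized and every rank calls it with an `h` of identical shape (the helper `pipeline_forward` is illustrative, not code from the diff):

```python
import mlx.core as mx


def pipeline_forward(h, local_layers, rank, size, stream=mx.cpu):
    # Wait for the activations produced by the rank that owns the previous layers.
    if rank < size - 1:
        h = mx.distributed.recv_like(h, rank + 1, stream=stream)
    for layer in local_layers:
        h = layer(h)
    # Hand the activations to the rank that owns the next layers.
    if rank != 0:
        h = mx.distributed.send(h, (rank - 1) % size, stream=stream)
    # all_gather keeps the send/recv in the compute graph and leaves every rank
    # holding the slice produced by rank 0, which owns the final layers.
    return mx.distributed.all_gather(h, stream=stream)[: h.shape[0]]
```

On a single process (size 1) this degenerates to just running the local layers, which matches the non-pipelined defaults set in the constructor.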

View File

@@ -1,3 +1,5 @@
+# Copyright © 2025 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Any, Optional, Tuple

View File

@@ -1,4 +1,4 @@
-# Copyright © 2024 Apple Inc.
+# Copyright © 2024-2025 Apple Inc.
 
 import math
 from dataclasses import dataclass
@@ -123,17 +123,16 @@ class MambaBlock(nn.Module):
             self.intermediate_size, self.hidden_size, bias=args.use_bias
         )
 
-    def ssm_step(self, x, state=None):
-        A = -mx.exp(self.A_log)
+    def ssm_step(self, x, A, state=None):
         D = self.D
         deltaBC = self.x_proj(x)
-        delta, B, C = mx.split(
-            deltaBC,
-            indices_or_sections=[
-                self.time_step_rank,
-                self.time_step_rank + self.ssm_state_size,
-            ],
-            axis=-1,
+        delta, B, C = map(
+            self.mixer_norm if self.use_bcdt_rms else lambda x: x,
+            mx.split(
+                deltaBC,
+                [self.time_step_rank, self.time_step_rank + self.ssm_state_size],
+                axis=-1,
+            ),
         )
         if self.use_bcdt_rms:
             delta, B, C = map(self.mixer_norm, (delta, B, C))
@@ -145,25 +144,40 @@ class MambaBlock(nn.Module):
         y = y + D * x
         return y, new_state
 
-    def __call__(self, x, cache):
-        B, T, D = x.shape
-
-        if cache is None:
-            cache = [None, None]
-
-        outputs = []
-        for t in range(T):
-            xt = x[:, t, :]
-            xz = self.in_proj(xt)
-            x_t, z_t = xz.split(indices_or_sections=2, axis=1)
-            conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
-            x_t = conv_out.squeeze(1)
-            x_t = nn.silu(x_t)
-            y_t, cache[1] = self.ssm_step(x_t, cache[1])
-            z_t = nn.silu(z_t)
-            output_t = y_t * z_t
-            output_t = self.out_proj(output_t)
-            outputs.append(output_t)
-        output = mx.stack(outputs, axis=1)
+    def _process_sequence(self, x, conv_cache, state_cache):
+        B, T, D = x.shape
+        xz = self.in_proj(x)
+        x, z = xz.split(indices_or_sections=2, axis=-1)
+
+        conv_out, new_conv_cache = self.conv1d(x, conv_cache)
+        x = nn.silu(conv_out)
+
+        A = -mx.exp(self.A_log)
+        outputs = []
+        current_state = state_cache
+        y = []
+        for t in range(T):
+            y_t, current_state = self.ssm_step(x[:, t], A, current_state)
+            y.append(y_t)
+        y = mx.stack(y, axis=1)
+
+        z = self.out_proj(nn.silu(z) * y)
+        return z, (new_conv_cache, current_state)
+
+    def __call__(self, x, cache):
+        if cache is None:
+            conv_cache, state_cache = None, None
+        else:
+            conv_cache, state_cache = cache[0], cache[1]
+
+        output, (new_conv_cache, new_state_cache) = self._process_sequence(
+            x, conv_cache, state_cache
+        )
+
+        if isinstance(cache, MambaCache):
+            cache[0] = new_conv_cache
+            cache[1] = new_state_cache
 
         return output
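The refactored block now projects and convolves the whole sequence at once and only loops over time for the SSM recurrence, stacking per-step outputs along the time axis. A toy, self-contained sketch of that scan pattern (the `scan_over_time` helper and the running-sum "state update" are purely illustrative, not the model's actual SSM):

```python
import mlx.core as mx


def scan_over_time(step, x, state):
    # x: (B, T, D); step maps (x_t, state) -> (y_t, new_state)
    ys = []
    for t in range(x.shape[1]):
        y_t, state = step(x[:, t], state)
        ys.append(y_t)
    return mx.stack(ys, axis=1), state


# Trivial example: a running sum standing in for the SSM state update.
y, final_state = scan_over_time(
    lambda x_t, s: (x_t + s, x_t + s), mx.ones((1, 4, 2)), mx.zeros((1, 2))
)
print(y.shape)  # (1, 4, 2)
```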

View File

@@ -1,4 +1,4 @@
-# Copyright © 2023-2024 Apple Inc.
+# Copyright © 2023-2025 Apple Inc.
 
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union

View File

@@ -140,8 +140,8 @@ def evaluate(
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
 ):
-    all_losses = 0
-    ntokens = 0
+    all_losses = mx.array(0.0)
+    ntokens = mx.array(0)
 
     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
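Initializing the accumulators as `mx.array`s keeps the running totals inside MLX; one reason that matters is that MLX arrays can be reduced across ranks (for example with `mx.distributed.all_sum`) without first converting them to Python numbers. A toy sketch of that accumulation pattern, not the surrounding `evaluate()` code:

```python
import mlx.core as mx

all_losses = mx.array(0.0)
ntokens = mx.array(0)

# Pretend batches of (mean loss, token count).
for batch_loss, batch_tokens in [(mx.array(2.0), 8), (mx.array(1.5), 4)]:
    all_losses += batch_loss * batch_tokens
    ntokens += batch_tokens

# all_sum is a no-op on a single process and a cross-rank sum when distributed.
avg_loss = (mx.distributed.all_sum(all_losses) / mx.distributed.all_sum(ntokens)).item()
print(avg_loss)
```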

View File

@@ -627,6 +627,7 @@ def load_config(model_path: Path) -> dict:
 def load_model(
     model_path: Path,
     lazy: bool = False,
+    strict: bool = True,
     model_config: dict = {},
     get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
 ) -> nn.Module:
@@ -638,6 +639,8 @@ def load_model(
         lazy (bool): If False eval the model parameters to make sure they are
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
+        strict (bool): Whether or not to raise an exception if weights don't
+            match. Default: ``True``
         model_config (dict, optional): Optional configuration parameters for the
             model. Defaults to an empty dictionary.
         get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
@@ -660,7 +663,7 @@ def load_model(
         # Try weight for back-compat
         weight_files = glob.glob(str(model_path / "weight*.safetensors"))
 
-    if not weight_files:
+    if not weight_files and strict:
         logging.error(f"No safetensors found in {model_path}")
         raise FileNotFoundError(f"No safetensors found in {model_path}")
 
@@ -694,7 +697,7 @@ def load_model(
             class_predicate=class_predicate,
         )
 
-    model.load_weights(list(weights.items()))
+    model.load_weights(list(weights.items()), strict=strict)
 
     if not lazy:
         mx.eval(model.parameters())
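A short usage sketch of the new `strict` flag (the model path is a placeholder; the config and tokenizer files are assumed to be present locally): with `lazy=True, strict=False`, `load_model` returns a model skeleton even when no safetensors have been downloaded yet, which is how `pipeline_generate.py` above works out which weight shards each rank needs.

```python
from pathlib import Path

from mlx.utils import tree_flatten
from mlx_lm.utils import load_model

model_path = Path("/path/to/model")  # placeholder: config present, weights optional
model, _ = load_model(model_path, lazy=True, strict=False)

# Parameter names of the lazily constructed, possibly weight-less model.
param_names = [name for name, _ in tree_flatten(model.parameters())]
print(len(param_names))
```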