Merge branch 'ml-explore:main' into adding-GRPO-training

Gökdeniz Gülmez
2025-02-12 11:10:10 +01:00
committed by GitHub
8 changed files with 89 additions and 26 deletions


@@ -295,7 +295,9 @@ class MLXLM(LM):
completions = []
for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
context = self._tokenize(context)
context = self.tokenizer.encode(
context, add_special_tokens=not self.use_chat_template
)
max_tokens = min(
self._max_tokens,
self.tokenizer.model_max_length - len(context),
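
A minimal standalone sketch of why the encode call above disables special tokens when a chat template is in use (illustrative only; the repo id and message are placeholders, and this is not the evaluation harness itself). A chat template already inserts the control tokens it needs, so re-encoding the rendered prompt with add_special_tokens=True would duplicate them; raw prompts keep the default behavior.

    from transformers import AutoTokenizer

    # Placeholder model id; any tokenizer with a chat template behaves similarly.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

    # Render the conversation through the chat template, then encode without
    # adding special tokens a second time.
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": "Hello"}],
        tokenize=False,
        add_generation_prompt=True,
    )
    ids = tokenizer.encode(prompt, add_special_tokens=False)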


@@ -378,9 +378,11 @@ class DeepseekV2Model(nn.Module):
# rank=pipeline_size-1 gets the first
self.pipeline_rank = group.rank()
self.pipeline_size = group.size()
layers_per_rank = (
len(self.layers) + self.pipeline_size - 1
) // self.pipeline_size
layers_per_rank = len(self.layers) // self.pipeline_size
extra = len(self.layers) - layers_per_rank * self.pipeline_size
if self.pipeline_rank < extra:
layers_per_rank += 1
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
self.end_idx = self.start_idx + layers_per_rank
self.num_layers = layers_per_rank
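
A small standalone sketch of the per-rank count computed above, assuming only the layer count and pipeline size are known (layers_for_rank is an illustrative helper, not a function in the repository). Floor division gives every rank a base share, and the remainder is handed out one extra layer at a time to the lowest ranks, so the per-rank counts sum exactly to the number of layers.

    def layers_for_rank(num_layers: int, pipeline_size: int, rank: int) -> int:
        # Base share for every rank, plus one extra layer for the first `extra` ranks.
        base = num_layers // pipeline_size
        extra = num_layers - base * pipeline_size
        return base + 1 if rank < extra else base

    # Example: 62 layers across 4 ranks -> [16, 16, 15, 15], which sums to 62.
    assert [layers_for_rank(62, 4, r) for r in range(4)] == [16, 16, 15, 15]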


@@ -410,9 +410,10 @@ class DeepseekV3Model(nn.Module):
# rank=pipeline_size-1 gets the first
self.pipeline_rank = group.rank()
self.pipeline_size = group.size()
layers_per_rank = (
len(self.layers) + self.pipeline_size - 1
) // self.pipeline_size
layers_per_rank = len(self.layers) // self.pipeline_size
extra = len(self.layers) - layers_per_rank * self.pipeline_size
if self.pipeline_rank < extra:
layers_per_rank += 1
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
self.end_idx = self.start_idx + layers_per_rank
self.layers = self.layers[: self.end_idx]
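
For comparison, a quick illustration with hypothetical numbers of why the earlier ceiling division was replaced: (n + size - 1) // size gives every rank the same count, which over-allocates slots whenever the layer count is not divisible by the pipeline size, while the floor-plus-remainder scheme above accounts for every layer exactly once.

    n, size = 61, 4
    old_per_rank = (n + size - 1) // size        # 16 on every rank -> 64 slots for 61 layers
    base = n // size
    extra = n - base * size
    new_counts = [base + 1 if r < extra else base for r in range(size)]
    print(old_per_rank * size, sum(new_counts))  # 64 vs 61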


@@ -233,8 +233,8 @@ def train(
n_tokens = 0
steps = 0
trained_tokens = 0
train_time = 0
# Main training loop
start = time.perf_counter()
for it, batch in zip(
range(1, args.iters + 1),
iterate_batches(
@@ -245,10 +245,11 @@ def train(
train=True,
),
):
tic = time.perf_counter()
# Report validation loss if needed, the first validation loss
# is always measured before any training.
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
stop = time.perf_counter()
tic = time.perf_counter()
val_loss = evaluate(
model=model,
dataset=val_dataset,
@@ -259,7 +260,7 @@ def train(
max_seq_length=args.max_seq_length,
iterate_batches=iterate_batches,
)
val_time = time.perf_counter() - stop
val_time = time.perf_counter() - tic
if rank == 0:
print(
f"Iter {it}: "
@@ -276,24 +277,23 @@ def train(
}
training_callback.on_val_loss_report(val_info)
start = time.perf_counter()
tic = time.perf_counter()
lvalue, toks = step(batch)
losses += lvalue
n_tokens += toks
steps += 1
mx.eval(state, losses, n_tokens)
train_time += time.perf_counter() - tic
# Report training loss if needed
if it % args.steps_per_report == 0 or it == args.iters:
stop = time.perf_counter()
train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
train_loss /= steps * mx.distributed.init().size()
n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
learning_rate = optimizer.learning_rate.item()
it_sec = args.steps_per_report / (stop - start)
tokens_sec = float(n_tokens) / (stop - start)
it_sec = args.steps_per_report / train_time
tokens_sec = float(n_tokens) / train_time
trained_tokens += n_tokens
peak_mem = mx.metal.get_peak_memory() / 1e9
if rank == 0:
@@ -322,7 +322,7 @@ def train(
losses = 0
n_tokens = 0
steps = 0
start = time.perf_counter()
train_time = 0
# Save adapter weights
if it % args.steps_per_save == 0:
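
A minimal sketch of the timing pattern introduced above, standalone rather than the trainer itself (the iteration count, cadences, and the run_validation/run_train_step helpers are placeholders): only time spent inside training steps is accumulated into train_time, so the reported iterations-per-second and tokens-per-second no longer include validation pauses, and the accumulator resets with the other counters after each report.

    import time

    def run_validation():   # hypothetical stand-in for evaluate(...)
        time.sleep(0.01)

    def run_train_step():   # hypothetical stand-in for step(batch)
        time.sleep(0.001)

    train_time = 0.0
    for it in range(1, 51):                       # placeholder iteration count
        if it == 1 or it % 25 == 0:               # placeholder eval cadence
            tic = time.perf_counter()
            run_validation()
            val_time = time.perf_counter() - tic  # reported separately from train_time

        tic = time.perf_counter()
        run_train_step()
        train_time += time.perf_counter() - tic   # accumulate training time only

        if it % 10 == 0:                          # placeholder report cadence
            print(f"{10 / train_time:.2f} it/s")
            train_time = 0.0                      # reset for the next reporting window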


@@ -89,6 +89,7 @@ def linear_to_lora_layers(
"mixtral",
"nemotron",
"stablelm",
"hunyuan",
"qwen2",
"qwen2_moe",
"phimoe",


@@ -13,7 +13,18 @@ import time
from dataclasses import dataclass
from pathlib import Path
from textwrap import dedent
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
from typing import (
Any,
Callable,
Dict,
Generator,
List,
NamedTuple,
Optional,
Tuple,
Type,
Union,
)
import mlx.core as mx
import mlx.nn as nn
@@ -65,6 +76,7 @@ class GenerationResponse:
Args:
text (str): The next segment of decoded text. This can be an empty string.
token (int): The next token.
from_draft (bool): Whether the token was generated by the draft model.
logprobs (mx.array): A vector of log probabilities.
prompt_tokens (int): The number of tokens in the prompt.
prompt_tps (float): The prompt processing tokens-per-second.
@@ -77,6 +89,7 @@ class GenerationResponse:
text: str
token: int
logprobs: mx.array
from_draft: bool
prompt_tokens: int
prompt_tps: float
generation_tokens: int
@@ -338,7 +351,7 @@ def speculative_generate_step(
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
) -> Generator[Tuple[mx.array, mx.array, bool], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
@@ -365,7 +378,8 @@ def speculative_generate_step(
when ``kv_bits`` is non-None. Default: ``0``.
Yields:
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
Tuple[mx.array, mx.array, bool]: One token, a vector of log probabilities,
and a bool indicating if the token was generated by the draft model
"""
y = prompt
@@ -450,12 +464,12 @@ def speculative_generate_step(
break
n += 1
ntoks += 1
yield tn, lpn
yield tn, lpn, True
if ntoks == max_tokens:
break
if ntoks < max_tokens:
ntoks += 1
yield tokens[n], logprobs[n]
yield tokens[n], logprobs[n], False
if ntoks == max_tokens:
break
@@ -463,7 +477,7 @@ def speculative_generate_step(
y = mx.array([tokens[n]], mx.uint32)
draft_y = y
# If we accpeted all the draft tokens, include the last
# If we accepted all the draft tokens, include the last
# draft token in the next draft step since it hasn't been
# processed yet by the draft model
if n == num_draft:
@@ -518,6 +532,10 @@ def stream_generate(
if draft_model is None:
kwargs.pop("num_draft_tokens", None)
token_generator = generate_step(prompt, model, **kwargs)
# from_draft always false for non-speculative generation
token_generator = (
(token, logprobs, False) for token, logprobs in token_generator
)
else:
kwargs.pop("max_kv_size", None)
token_generator = speculative_generate_step(
@@ -526,7 +544,7 @@ def stream_generate(
with wired_limit(model, [generation_stream]):
detokenizer.reset()
tic = time.perf_counter()
for n, (token, logprobs) in enumerate(token_generator):
for n, (token, logprobs, from_draft) in enumerate(token_generator):
if n == 0:
prompt_time = time.perf_counter() - tic
prompt_tps = prompt.size / prompt_time
@@ -540,6 +558,7 @@ def stream_generate(
text=detokenizer.last_segment,
token=token,
logprobs=logprobs,
from_draft=from_draft,
prompt_tokens=prompt.size,
prompt_tps=prompt_tps,
generation_tokens=n + 1,
@@ -553,6 +572,7 @@ def stream_generate(
text=detokenizer.last_segment,
token=token,
logprobs=logprobs,
from_draft=from_draft,
prompt_tokens=prompt.size,
prompt_tps=prompt_tps,
generation_tokens=n + 1,
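
As a usage note, a hedged sketch of how the new field can be consumed from stream_generate to estimate how many emitted tokens came from the draft model (the model paths and prompt are placeholders, and the acceptance bookkeeping is illustrative rather than part of the library):

    from mlx_lm import load, stream_generate

    model, tokenizer = load("mlx-community/SomeModel-8B-4bit")     # placeholder path
    draft_model, _ = load("mlx-community/SomeModel-0.5B-4bit")     # placeholder path

    draft_tokens = total_tokens = 0
    for response in stream_generate(
        model, tokenizer, "Write a haiku about autumn.", draft_model=draft_model
    ):
        total_tokens += 1
        draft_tokens += int(response.from_draft)
        print(response.text, end="", flush=True)

    print(f"\ndraft model produced {draft_tokens}/{total_tokens} tokens")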