Merge branch 'ml-explore:main' into adding-dpo-training

Gökdeniz Gülmez committed 2025-02-12 11:09:58 +01:00 (committed by GitHub)
commit 4b44434c54
8 changed files with 89 additions and 26 deletions


@@ -295,7 +295,9 @@ class MLXLM(LM):
         completions = []
         for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
-            context = self._tokenize(context)
+            context = self.tokenizer.encode(
+                context, add_special_tokens=not self.use_chat_template
+            )
             max_tokens = min(
                 self._max_tokens,
                 self.tokenizer.model_max_length - len(context),
             )
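Note on the hunk above: when `use_chat_template` is set, the template has already wrapped the prompt (and typically inserted the special tokens), so they must not be added again at encode time. A minimal sketch of the idea, assuming a Hugging Face-style tokenizer; the model path is reused from the tests and serves only as an illustration:

# Sketch: suppress special tokens when a chat template already supplied them.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mlx-community/Qwen1.5-0.5B-Chat-4bit")
use_chat_template = tokenizer.chat_template is not None
context = "hello"
# Mirrors the new call in the diff: special tokens only for raw prompts.
ids = tokenizer.encode(context, add_special_tokens=not use_chat_template)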


@@ -378,9 +378,11 @@ class DeepseekV2Model(nn.Module):
         # rank=pipeline_size-1 gets the first
         self.pipeline_rank = group.rank()
         self.pipeline_size = group.size()
-        layers_per_rank = (
-            len(self.layers) + self.pipeline_size - 1
-        ) // self.pipeline_size
+        layers_per_rank = len(self.layers) // self.pipeline_size
+        extra = len(self.layers) - layers_per_rank * self.pipeline_size
+        if self.pipeline_rank < extra:
+            layers_per_rank += 1
+
         self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
         self.end_idx = self.start_idx + layers_per_rank
         self.num_layers = layers_per_rank
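For context on this hunk: the old ceil division assigned every rank the same count of len(layers) rounded up by pipeline_size, which over-allocates when the layer count is not evenly divisible (e.g. 27 layers over 4 ranks gives 4 x 7 = 28 slots for 27 layers). The new code hands the remainder out one layer at a time to the low ranks, so the per-rank counts sum exactly to the layer count. A standalone sanity check of the new scheme, with illustrative numbers that are not part of the diff:

# Per-rank layer counts under the new partitioning: ranks 0..extra-1
# take one extra layer, so the counts always sum to the total.
def layers_for_rank(num_layers: int, rank: int, size: int) -> int:
    per_rank = num_layers // size
    extra = num_layers - per_rank * size
    return per_rank + (1 if rank < extra else 0)

counts = [layers_for_rank(27, r, 4) for r in range(4)]
assert counts == [7, 7, 7, 6] and sum(counts) == 27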


@@ -410,9 +410,10 @@ class DeepseekV3Model(nn.Module):
         # rank=pipeline_size-1 gets the first
         self.pipeline_rank = group.rank()
         self.pipeline_size = group.size()
-        layers_per_rank = (
-            len(self.layers) + self.pipeline_size - 1
-        ) // self.pipeline_size
+        layers_per_rank = len(self.layers) // self.pipeline_size
+        extra = len(self.layers) - layers_per_rank * self.pipeline_size
+        if self.pipeline_rank < extra:
+            layers_per_rank += 1
         self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
         self.end_idx = self.start_idx + layers_per_rank
         self.layers = self.layers[: self.end_idx]


@@ -233,8 +233,8 @@ def train(
     n_tokens = 0
     steps = 0
     trained_tokens = 0
+    train_time = 0
     # Main training loop
-    start = time.perf_counter()
     for it, batch in zip(
         range(1, args.iters + 1),
         iterate_batches(
@@ -245,10 +245,11 @@ def train(
             train=True,
         ),
     ):
+        tic = time.perf_counter()
         # Report validation loss if needed, the first validation loss
         # is always measured before any training.
         if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
-            stop = time.perf_counter()
+            tic = time.perf_counter()
             val_loss = evaluate(
                 model=model,
                 dataset=val_dataset,
@@ -259,7 +260,7 @@ def train(
                 max_seq_length=args.max_seq_length,
                 iterate_batches=iterate_batches,
             )
-            val_time = time.perf_counter() - stop
+            val_time = time.perf_counter() - tic
             if rank == 0:
                 print(
                     f"Iter {it}: "
@@ -276,24 +277,23 @@ def train(
                 }
                 training_callback.on_val_loss_report(val_info)
 
-            start = time.perf_counter()
+            tic = time.perf_counter()
 
         lvalue, toks = step(batch)
         losses += lvalue
         n_tokens += toks
         steps += 1
         mx.eval(state, losses, n_tokens)
+        train_time += time.perf_counter() - tic
 
         # Report training loss if needed
         if it % args.steps_per_report == 0 or it == args.iters:
-            stop = time.perf_counter()
-
             train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
             train_loss /= steps * mx.distributed.init().size()
             n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
             learning_rate = optimizer.learning_rate.item()
-            it_sec = args.steps_per_report / (stop - start)
-            tokens_sec = float(n_tokens) / (stop - start)
+            it_sec = args.steps_per_report / train_time
+            tokens_sec = float(n_tokens) / train_time
             trained_tokens += n_tokens
             peak_mem = mx.metal.get_peak_memory() / 1e9
             if rank == 0:
@@ -322,7 +322,7 @@ def train(
             losses = 0
             n_tokens = 0
             steps = 0
-            start = time.perf_counter()
+            train_time = 0
 
         # Save adapter weights
         if it % args.steps_per_save == 0:
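The net effect of the timing changes above: `train_time` accumulates only the time spent in training steps, and the clock is reset after each validation pass, so validation no longer inflates the reported it/sec and tokens/sec. A minimal standalone sketch of the pattern, with sleeps standing in for real work (illustrative only, not part of the diff):

# Only step time accumulates into train_time; the clock is reset after
# the "validation" branch so its cost is excluded from throughput.
import time

train_time = 0.0
steps_per_report = 10
for it in range(1, steps_per_report + 1):
    tic = time.perf_counter()
    if it == 1:  # stand-in for the periodic validation branch
        time.sleep(0.01)  # validation work, excluded below
        tic = time.perf_counter()  # reset, as in the diff
    time.sleep(0.001)  # stand-in for step(batch)
    train_time += time.perf_counter() - tic
print(f"it/sec ~ {steps_per_report / train_time:.1f}")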


@@ -89,6 +89,7 @@ def linear_to_lora_layers(
         "mixtral",
         "nemotron",
         "stablelm",
+        "hunyuan",
         "qwen2",
         "qwen2_moe",
         "phimoe",


@@ -13,7 +13,18 @@ import time
 from dataclasses import dataclass
 from pathlib import Path
 from textwrap import dedent
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    NamedTuple,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -65,6 +76,7 @@ class GenerationResponse:
     Args:
         text (str): The next segment of decoded text. This can be an empty string.
         token (int): The next token.
+        from_draft (bool): Whether the token was generated by the draft model.
         logprobs (mx.array): A vector of log probabilities.
         prompt_tokens (int): The number of tokens in the prompt.
         prompt_tps (float): The prompt processing tokens-per-second.
@@ -77,6 +89,7 @@ class GenerationResponse:
     text: str
     token: int
     logprobs: mx.array
+    from_draft: bool
     prompt_tokens: int
     prompt_tps: float
     generation_tokens: int
@@ -338,7 +351,7 @@ def speculative_generate_step(
     kv_bits: Optional[int] = None,
     kv_group_size: int = 64,
     quantized_kv_start: int = 0,
-) -> Generator[Tuple[mx.array, mx.array], None, None]:
+) -> Generator[Tuple[mx.array, mx.array, bool], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
 
@@ -365,7 +378,8 @@ def speculative_generate_step(
           when ``kv_bits`` is non-None. Default: ``0``.
 
     Yields:
-        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
+        Tuple[mx.array, mx.array, bool]: One token, a vector of log probabilities,
+        and a bool indicating if the token was generated by the draft model
     """
 
     y = prompt
@@ -450,12 +464,12 @@ def speculative_generate_step(
                     break
                 n += 1
                 ntoks += 1
-                yield tn, lpn
+                yield tn, lpn, True
                 if ntoks == max_tokens:
                     break
             if ntoks < max_tokens:
                 ntoks += 1
-                yield tokens[n], logprobs[n]
+                yield tokens[n], logprobs[n], False
                 if ntoks == max_tokens:
                     break
 
@@ -463,7 +477,7 @@ def speculative_generate_step(
             y = mx.array([tokens[n]], mx.uint32)
             draft_y = y
 
-            # If we accpeted all the draft tokens, include the last
+            # If we accepted all the draft tokens, include the last
             # draft token in the next draft step since it hasn't been
             # processed yet by the draft model
             if n == num_draft:
@@ -518,6 +532,10 @@ def stream_generate(
     if draft_model is None:
         kwargs.pop("num_draft_tokens", None)
         token_generator = generate_step(prompt, model, **kwargs)
+        # from_draft always false for non-speculative generation
+        token_generator = (
+            (token, logprobs, False) for token, logprobs in token_generator
+        )
     else:
         kwargs.pop("max_kv_size", None)
         token_generator = speculative_generate_step(
@@ -526,7 +544,7 @@ def stream_generate(
     with wired_limit(model, [generation_stream]):
         detokenizer.reset()
         tic = time.perf_counter()
-        for n, (token, logprobs) in enumerate(token_generator):
+        for n, (token, logprobs, from_draft) in enumerate(token_generator):
             if n == 0:
                 prompt_time = time.perf_counter() - tic
                 prompt_tps = prompt.size / prompt_time
@@ -540,6 +558,7 @@ def stream_generate(
                 text=detokenizer.last_segment,
                 token=token,
                 logprobs=logprobs,
+                from_draft=from_draft,
                 prompt_tokens=prompt.size,
                 prompt_tps=prompt_tps,
                 generation_tokens=n + 1,
@@ -553,6 +572,7 @@ def stream_generate(
             text=detokenizer.last_segment,
             token=token,
             logprobs=logprobs,
+            from_draft=from_draft,
             prompt_tokens=prompt.size,
             prompt_tps=prompt_tps,
             generation_tokens=n + 1,
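With the new `from_draft` field, callers of `stream_generate` can see which tokens the draft model produced. A hypothetical consumer, reusing the model path from the tests (illustrative only, not part of the diff): measure what share of tokens the draft model contributed during speculative decoding.

# Count the fraction of generated tokens that came from the draft model.
from mlx_lm.utils import load, stream_generate

model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")
draft_model, _ = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")

drafted = 0
total = 0
for response in stream_generate(
    model=model,
    tokenizer=tokenizer,
    prompt="hello",
    max_tokens=32,
    draft_model=draft_model,
    num_draft_tokens=2,
):
    total += 1
    drafted += response.from_draft
print(f"draft tokens: {drafted}/{total}")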


@@ -1,17 +1,24 @@
 # Copyright © 2024 Apple Inc.
 
 import unittest
+from typing import List
 
 from mlx_lm.sample_utils import make_logits_processors
-from mlx_lm.utils import generate, load
+from mlx_lm.utils import (
+    GenerationResponse,
+    generate,
+    load,
+    make_sampler,
+    stream_generate,
+)
 
 
 class TestGenerate(unittest.TestCase):
 
     @classmethod
     def setUpClass(cls):
-        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
-        cls.model, cls.tokenizer = load(HF_MODEL_PATH)
+        cls.HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
+        cls.model, cls.tokenizer = load(cls.HF_MODEL_PATH)
 
     def test_generate(self):
         # Simple test that generation runs
@@ -51,6 +58,34 @@ class TestGenerate(unittest.TestCase):
         )
         self.assertEqual(len(all_toks), len(init_toks) + 5)
 
+    def test_stream_generate_speculative(self):
+        # Use the same model as the draft model; this is not a speed test
+        draft_model, _ = load(self.HF_MODEL_PATH)
+
+        results: List[GenerationResponse] = []
+        drafted: List[bool] = []
+
+        # make a deterministic sampler
+        sampler = make_sampler(temp=0.0)
+        for generation_result in stream_generate(
+            model=self.model,
+            tokenizer=self.tokenizer,
+            prompt="hello",
+            max_tokens=5,
+            draft_model=draft_model,
+            num_draft_tokens=2,
+            sampler=sampler,
+        ):
+            drafted.append(generation_result.from_draft)
+            results.append(generation_result)
+
+        self.assertEqual(len(results), 5)
+        # since num_draft_tokens is 2 and the draft model is the same, the
+        # first 2 generations should be drafts, the third should come
+        # from the target model, and the last two should be drafts
+        self.assertEqual(drafted, [True, True, False, True, True])
+
 
 if __name__ == "__main__":
     unittest.main()
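Why the assertion expects [True, True, False, True, True]: with an identical draft model every draft token is accepted, and each accepted round of `num_draft_tokens` drafts is followed by one verified token from the target model, until `max_tokens` cuts generation short. A small sketch of that accounting (illustrative only; it models the yield order, not the decoding itself):

# Replay the draft/target yield pattern assuming all drafts are accepted.
num_draft_tokens = 2
max_tokens = 5
drafted = []
while len(drafted) < max_tokens:
    for _ in range(num_draft_tokens):  # accepted draft tokens
        drafted.append(True)
        if len(drafted) == max_tokens:
            break
    else:
        drafted.append(False)  # the target model's verified token
assert drafted == [True, True, False, True, True]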


@@ -195,6 +195,8 @@ def transcribe(
         seek_points.append(0)
     if len(seek_points) % 2 == 1:
         seek_points.append(content_frames)
+    else:
+        seek_points[-1] = min(content_frames, seek_points[-1])
     seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
 
     punctuation = "\"'“¿([{-\"'.。,!?::”)]}、"
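The new `else` branch clamps a user-supplied final seek point to the actual audio length, so an over-long clip range can no longer seek past the end of the audio. A self-contained check with hypothetical frame counts (not part of the diff):

# An even number of seek points means the user supplied an explicit end;
# clamp it to the available frames before pairing points into clips.
content_frames = 1500  # e.g. ~15 s of audio
seek_points = [0, 3000]  # user asked for a clip ending past the audio

if len(seek_points) % 2 == 1:
    seek_points.append(content_frames)
else:
    seek_points[-1] = min(content_frames, seek_points[-1])

assert list(zip(seek_points[::2], seek_points[1::2])) == [(0, 1500)]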