Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-26 18:51:18 +08:00)

Commit 4b44434c54: Merge branch 'ml-explore:main' into adding-dpo-training
@@ -295,7 +295,9 @@ class MLXLM(LM):
         completions = []

         for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
-            context = self._tokenize(context)
+            context = self.tokenizer.encode(
+                context, add_special_tokens=not self.use_chat_template
+            )
             max_tokens = min(
                 self._max_tokens,
                 self.tokenizer.model_max_length - len(context),
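Note: the hunk above switches the evaluation harness from the private _tokenize helper to a direct tokenizer.encode call that adds special tokens only when no chat template is in use. A minimal sketch of the idea, assuming the Hugging Face transformers tokenizer API and reusing the model repo named later in this diff; the evaluator itself tracks the flag as self.use_chat_template:

# Sketch: why special tokens are skipped when a chat template is used.
# Assumes the Hugging Face `transformers` tokenizer API; the model repo is the
# one the tests in this diff use, and the flag mirrors self.use_chat_template.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mlx-community/Qwen1.5-0.5B-Chat-4bit")

messages = [{"role": "user", "content": "hello"}]
use_chat_template = tokenizer.chat_template is not None

if use_chat_template:
    # The template already wraps the text in the model's special tokens.
    context = tokenizer.apply_chat_template(messages, tokenize=False)
else:
    context = "hello"

# Mirrors the new evaluate.py call: only add special tokens when the chat
# template has not inserted them already, so they are never duplicated.
ids = tokenizer.encode(context, add_special_tokens=not use_chat_template)
print(len(ids))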
@@ -378,9 +378,11 @@ class DeepseekV2Model(nn.Module):
         # rank=pipeline_size-1 gets the first
         self.pipeline_rank = group.rank()
         self.pipeline_size = group.size()
-        layers_per_rank = (
-            len(self.layers) + self.pipeline_size - 1
-        ) // self.pipeline_size
+        layers_per_rank = len(self.layers) // self.pipeline_size
+        extra = len(self.layers) - layers_per_rank * self.pipeline_size
+        if self.pipeline_rank < extra:
+            layers_per_rank += 1
+
         self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
         self.end_idx = self.start_idx + layers_per_rank
         self.num_layers = layers_per_rank
@@ -410,9 +410,10 @@ class DeepseekV3Model(nn.Module):
         # rank=pipeline_size-1 gets the first
         self.pipeline_rank = group.rank()
         self.pipeline_size = group.size()
-        layers_per_rank = (
-            len(self.layers) + self.pipeline_size - 1
-        ) // self.pipeline_size
+        layers_per_rank = len(self.layers) // self.pipeline_size
+        extra = len(self.layers) - layers_per_rank * self.pipeline_size
+        if self.pipeline_rank < extra:
+            layers_per_rank += 1
         self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
         self.end_idx = self.start_idx + layers_per_rank
         self.layers = self.layers[: self.end_idx]
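Note: both DeepSeek pipeline hunks replace the ceiling-division split with a floor split plus a remainder, so the per-rank layer counts add up exactly to the number of layers. A small self-contained sketch of the same arithmetic (the helper name and the example sizes are illustrative, not taken from the diff):

# Sketch of the new layer partitioning: floor division, plus one extra layer
# for each of the first `extra` ranks.
def layers_for_rank(num_layers: int, pipeline_rank: int, pipeline_size: int) -> int:
    layers_per_rank = num_layers // pipeline_size
    extra = num_layers - layers_per_rank * pipeline_size
    if pipeline_rank < extra:
        layers_per_rank += 1
    return layers_per_rank

# e.g. 61 layers over 4 ranks (not an even multiple):
counts = [layers_for_rank(61, rank, 4) for rank in range(4)]
print(counts)       # [16, 15, 15, 15]
print(sum(counts))  # 61 -- every layer is assigned exactly once
# The old ceil-based split gave every rank 16 layers (4 * 16 = 64), which
# does not match the 61 layers that actually exist.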
@@ -233,8 +233,8 @@ def train(
     n_tokens = 0
     steps = 0
     trained_tokens = 0
+    train_time = 0
     # Main training loop
-    start = time.perf_counter()
     for it, batch in zip(
         range(1, args.iters + 1),
         iterate_batches(
@@ -245,10 +245,11 @@ def train(
             train=True,
         ),
     ):
+        tic = time.perf_counter()
         # Report validation loss if needed, the first validation loss
         # is always measured before any training.
         if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
-            stop = time.perf_counter()
+            tic = time.perf_counter()
             val_loss = evaluate(
                 model=model,
                 dataset=val_dataset,
@@ -259,7 +260,7 @@ def train(
                 max_seq_length=args.max_seq_length,
                 iterate_batches=iterate_batches,
             )
-            val_time = time.perf_counter() - stop
+            val_time = time.perf_counter() - tic
             if rank == 0:
                 print(
                     f"Iter {it}: "
@@ -276,24 +277,23 @@ def train(
                 }
                 training_callback.on_val_loss_report(val_info)

-            start = time.perf_counter()
+            tic = time.perf_counter()

         lvalue, toks = step(batch)
         losses += lvalue
         n_tokens += toks
         steps += 1
         mx.eval(state, losses, n_tokens)
+        train_time += time.perf_counter() - tic

         # Report training loss if needed
         if it % args.steps_per_report == 0 or it == args.iters:
-            stop = time.perf_counter()
-
             train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
             train_loss /= steps * mx.distributed.init().size()
             n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
             learning_rate = optimizer.learning_rate.item()
-            it_sec = args.steps_per_report / (stop - start)
-            tokens_sec = float(n_tokens) / (stop - start)
+            it_sec = args.steps_per_report / train_time
+            tokens_sec = float(n_tokens) / train_time
             trained_tokens += n_tokens
             peak_mem = mx.metal.get_peak_memory() / 1e9
             if rank == 0:
@@ -322,7 +322,7 @@ def train(
             losses = 0
             n_tokens = 0
             steps = 0
-            start = time.perf_counter()
+            train_time = 0

         # Save adapter weights
         if it % args.steps_per_save == 0:
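Note: taken together, the train() hunks above report throughput from an accumulated train_time that only counts time spent inside training steps, instead of wall-clock time between reports (which previously included validation). A minimal sketch of the pattern, with placeholder work standing in for the real step and validation:

# Sketch of the accumulated-active-time pattern: only time spent in the
# training step contributes to it/sec and tokens/sec. `fake_step`, the sleeps,
# and the loop bounds are placeholders, not code from the diff.
import time

def fake_step():
    time.sleep(0.01)          # stands in for forward/backward/update
    return 32                 # pretend 32 tokens were processed

train_time = 0.0
n_tokens = 0
steps_per_report = 5

for it in range(1, 11):
    if it == 1:               # stands in for the periodic validation pass
        time.sleep(0.05)      # validation time is *not* counted

    tic = time.perf_counter()
    n_tokens += fake_step()
    train_time += time.perf_counter() - tic

    if it % steps_per_report == 0:
        print(f"it/sec={steps_per_report / train_time:.1f} "
              f"tok/sec={n_tokens / train_time:.1f}")
        n_tokens = 0
        train_time = 0.0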
|
@ -89,6 +89,7 @@ def linear_to_lora_layers(
|
|||||||
"mixtral",
|
"mixtral",
|
||||||
"nemotron",
|
"nemotron",
|
||||||
"stablelm",
|
"stablelm",
|
||||||
|
"hunyuan",
|
||||||
"qwen2",
|
"qwen2",
|
||||||
"qwen2_moe",
|
"qwen2_moe",
|
||||||
"phimoe",
|
"phimoe",
|
||||||
|
@@ -13,7 +13,18 @@ import time
 from dataclasses import dataclass
 from pathlib import Path
 from textwrap import dedent
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    NamedTuple,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)

 import mlx.core as mx
 import mlx.nn as nn
@@ -65,6 +76,7 @@ class GenerationResponse:
     Args:
         text (str): The next segment of decoded text. This can be an empty string.
         token (int): The next token.
+        from_draft (bool): Whether the token was generated by the draft model.
         logprobs (mx.array): A vector of log probabilities.
         prompt_tokens (int): The number of tokens in the prompt.
         prompt_tps (float): The prompt processing tokens-per-second.
@@ -77,6 +89,7 @@ class GenerationResponse:
     text: str
     token: int
     logprobs: mx.array
+    from_draft: bool
     prompt_tokens: int
     prompt_tps: float
     generation_tokens: int
@@ -338,7 +351,7 @@ def speculative_generate_step(
     kv_bits: Optional[int] = None,
     kv_group_size: int = 64,
     quantized_kv_start: int = 0,
-) -> Generator[Tuple[mx.array, mx.array], None, None]:
+) -> Generator[Tuple[mx.array, mx.array, bool], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.

@@ -365,7 +378,8 @@ def speculative_generate_step(
           when ``kv_bits`` is non-None. Default: ``0``.

     Yields:
-        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
+        Tuple[mx.array, mx.array, bool]: One token, a vector of log probabilities,
+          and a bool indicating if the token was generated by the draft model
     """

     y = prompt
@@ -450,12 +464,12 @@ def speculative_generate_step(
                 break
             n += 1
             ntoks += 1
-            yield tn, lpn
+            yield tn, lpn, True
             if ntoks == max_tokens:
                 break
         if ntoks < max_tokens:
             ntoks += 1
-            yield tokens[n], logprobs[n]
+            yield tokens[n], logprobs[n], False

         if ntoks == max_tokens:
             break
@@ -463,7 +477,7 @@ def speculative_generate_step(
         y = mx.array([tokens[n]], mx.uint32)
         draft_y = y

-        # If we accpeted all the draft tokens, include the last
+        # If we accepted all the draft tokens, include the last
         # draft token in the next draft step since it hasn't been
         # processed yet by the draft model
         if n == num_draft:
@@ -518,6 +532,10 @@ def stream_generate(
     if draft_model is None:
         kwargs.pop("num_draft_tokens", None)
         token_generator = generate_step(prompt, model, **kwargs)
+        # from_draft always false for non-speculative generation
+        token_generator = (
+            (token, logprobs, False) for token, logprobs in token_generator
+        )
     else:
         kwargs.pop("max_kv_size", None)
         token_generator = speculative_generate_step(
@@ -526,7 +544,7 @@ def stream_generate(
     with wired_limit(model, [generation_stream]):
         detokenizer.reset()
         tic = time.perf_counter()
-        for n, (token, logprobs) in enumerate(token_generator):
+        for n, (token, logprobs, from_draft) in enumerate(token_generator):
             if n == 0:
                 prompt_time = time.perf_counter() - tic
                 prompt_tps = prompt.size / prompt_time
@@ -540,6 +558,7 @@ def stream_generate(
                 text=detokenizer.last_segment,
                 token=token,
                 logprobs=logprobs,
+                from_draft=from_draft,
                 prompt_tokens=prompt.size,
                 prompt_tps=prompt_tps,
                 generation_tokens=n + 1,
@@ -553,6 +572,7 @@ def stream_generate(
             text=detokenizer.last_segment,
             token=token,
             logprobs=logprobs,
+            from_draft=from_draft,
             prompt_tokens=prompt.size,
             prompt_tps=prompt_tps,
             generation_tokens=n + 1,
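Note: with from_draft threaded through speculative_generate_step and stream_generate, a caller can see which tokens were accepted from the draft model. A hedged usage sketch; the model repo and keyword arguments mirror the test added below, but any MLX-format model pair should behave the same way:

# Sketch: measuring how many generated tokens were accepted from the draft
# model via the new GenerationResponse.from_draft field. Model repo and prompt
# are illustrative (the same ones the test below uses).
from mlx_lm.utils import load, stream_generate

model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")
draft_model, _ = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")

drafted = 0
total = 0
for response in stream_generate(
    model=model,
    tokenizer=tokenizer,
    prompt="hello",
    max_tokens=32,
    draft_model=draft_model,
    num_draft_tokens=2,
):
    total += 1
    drafted += int(response.from_draft)  # True for draft-accepted tokens
    print(response.text, end="", flush=True)

print(f"\n{drafted}/{total} tokens came from the draft model")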
@@ -1,17 +1,24 @@
 # Copyright © 2024 Apple Inc.

 import unittest
+from typing import List

 from mlx_lm.sample_utils import make_logits_processors
-from mlx_lm.utils import generate, load
+from mlx_lm.utils import (
+    GenerationResponse,
+    generate,
+    load,
+    make_sampler,
+    stream_generate,
+)


 class TestGenerate(unittest.TestCase):

     @classmethod
     def setUpClass(cls):
-        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
-        cls.model, cls.tokenizer = load(HF_MODEL_PATH)
+        cls.HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
+        cls.model, cls.tokenizer = load(cls.HF_MODEL_PATH)

     def test_generate(self):
         # Simple test that generation runs
@@ -51,6 +58,34 @@ class TestGenerate(unittest.TestCase):
         )
         self.assertEqual(len(all_toks), len(init_toks) + 5)

+    def test_stream_generate_speculative(self):
+        # Use same model as draft model, this is not a speed test
+        draft_model, _ = load(self.HF_MODEL_PATH)
+
+        results: List[GenerationResponse] = []
+        drafted: List[bool] = []
+
+        # make a determinate sampler
+        sampler = make_sampler(temp=0.0)
+
+        for generation_result in stream_generate(
+            model=self.model,
+            tokenizer=self.tokenizer,
+            prompt="hello",
+            max_tokens=5,
+            draft_model=draft_model,
+            num_draft_tokens=2,
+            sampler=sampler,
+        ):
+            drafted.append(generation_result.from_draft)
+            results.append(generation_result)
+
+        self.assertEqual(len(results), 5)
+        # since num_draft_tokens is 2 and draft model is the same, the
+        # first 2 generations should be drafts, the third should come
+        # from the target model, and last two should be drafts
+        self.assertEqual(drafted, [True, True, False, True, True])
+

 if __name__ == "__main__":
     unittest.main()
@@ -195,6 +195,8 @@ def transcribe(
     seek_points.append(0)
     if len(seek_points) % 2 == 1:
         seek_points.append(content_frames)
+    else:
+        seek_points[-1] = min(content_frames, seek_points[-1])
     seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))

     punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
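Note: the whisper hunk handles an even number of seek points by clamping the final end point to the audio length rather than appending another one. A tiny worked sketch of the pairing logic; the frame numbers are made up:

# Sketch of the seek-point pairing after the change. The numbers are
# illustrative; content_frames stands for the audio length in mel frames.
from typing import List, Tuple

content_frames = 3000
seek_points = [0, 5000]            # user-supplied clip end exceeds the audio

if len(seek_points) % 2 == 1:
    seek_points.append(content_frames)
else:
    # New behaviour: clamp the last end point instead of appending another.
    seek_points[-1] = min(content_frames, seek_points[-1])

seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
print(seek_clips)                  # [(0, 3000)]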