diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index abdd7c36..5bf17ef8 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -19,6 +19,7 @@ class GRPODataset:
tokenizer: PreTrainedTokenizer,
prompt_key: str = "prompt",
answer_key: str = "answer",
+ system_key: str = "system",
use_chat_template: bool = False,
use_prompt: bool = False
):
@@ -27,9 +28,11 @@ class GRPODataset:
prompt_str = str(item[prompt_key])
answer_str = str(item[answer_key])
if use_chat_template:
+ default_system_str = "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here ."
+ system_str = item.get(system_key, default_system_str)
prompt_tokens = tokenizer.apply_chat_template(
[
- {'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here ."""},
+ {'role': 'system', 'content': system_str},
{'role': 'user', 'content': prompt_str}
],
add_generation_prompt=True
diff --git a/llms/mlx_lm/tuner/grpo_reward_functions.py b/llms/mlx_lm/tuner/grpo_reward_functions.py
new file mode 100644
index 00000000..59dfbfef
--- /dev/null
+++ b/llms/mlx_lm/tuner/grpo_reward_functions.py
@@ -0,0 +1,82 @@
+from typing import List, Optional, Callable
+import re
+
+
+RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]
+
+
+def r1_extract_xml_answer(text: str) -> str:
+ try:
+ answer = text.split("")[-1]
+ answer = answer.split("")[0]
+ return answer.strip()
+ except:
+ print("r1_extract_xml_answer returned empty string")
+ return ""
+
+def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+ if not completions:
+ return [0.0] * len(prompts)
+ extracted_responses = [r1_extract_xml_answer(r) for r in completions]
+ return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
+
+def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+ if not completions or not answer:
+ return [0.0] * len(prompts)
+ extracted_responses = [r1_extract_xml_answer(r) for r in completions]
+ return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]
+
+def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+ if not completions:
+ return [0.0] * len(prompts)
+
+ scores = []
+ for completion in completions:
+ if not completion:
+ scores.append(0.0)
+ continue
+
+ reason_start = completion.find("")
+ reason_end = completion.find("")
+ answer_start = completion.find("")
+ answer_end = completion.find("")
+
+ if (reason_start != -1 and reason_end != -1 and
+ answer_start != -1 and answer_end != -1 and
+ reason_start < reason_end < answer_start < answer_end):
+ reason_content = completion[reason_start+13:reason_end].strip()
+ answer_content = completion[answer_start+8:answer_end].strip()
+ if reason_content and answer_content:
+ scores.append(0.5)
+ continue
+ scores.append(0.0)
+ return scores
+
+def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+ if not completions:
+ return [0.0] * len(prompts)
+ pattern = r"\n.*?\n\n*?"
+ matches = [bool(re.search(pattern, r)) if r else False for r in completions]
+ return [0.5 if match else 0.0 for match in matches]
+
+def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+ if not completions:
+ return [0.0] * len(prompts)
+ scores = []
+ for text in completions:
+ if not text:
+ scores.append(0.0)
+ continue
+ count = 0.0
+ if text.count("\n") == 1:
+ count += 0.125
+ if text.count("") == 1:
+ count += 0.125
+ if text.count("") == 1:
+ count += 0.125
+ if text.count("") == 1:
+ count += 0.125
+ end_text = text.split("")[-1]
+ count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
+ scores.append(max(0.0, count))
+ return scores
\ No newline at end of file
diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 9d9051f0..5ec3020a 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -1,16 +1,16 @@
# Copyright © 2024 Apple Inc.
-from typing import List, Optional, Callable
+from typing import List, Optional, Tuple, Generator, Callable, Any
from dataclasses import dataclass, field
from pathlib import Path
import time
-import re
from mlx.utils import tree_flatten
import mlx.core as mx
import mlx.nn as nn
import numpy as np
+from .grpo_reward_functions import r1_accuracy_reward_func, r1_int_reward_func, r1_strict_format_reward_func, r1_soft_format_reward_func, r1_count_xml, RewardFunctions
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
from ..utils import generate_step
from ..models import cache
@@ -50,88 +50,16 @@ class GRPOTrainingArgs(TrainingArgs):
)
-RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]
-
-
-def r1_extract_xml_answer(text: str) -> str:
- try:
- answer = text.split("")[-1]
- answer = answer.split("")[0]
- return answer.strip()
- except:
- print("r1_extract_xml_answer returned empty string")
- return ""
-
-def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
- if not completions:
- return [0.0] * len(prompts)
- extracted_responses = [r1_extract_xml_answer(r) for r in completions]
- return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
-
-def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
- if not completions or not answer:
- return [0.0] * len(prompts)
- extracted_responses = [r1_extract_xml_answer(r) for r in completions]
- return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]
-
-def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
- if not completions:
- return [0.0] * len(prompts)
-
- scores = []
- for completion in completions:
- if not completion:
- scores.append(0.0)
- continue
-
- reason_start = completion.find("")
- reason_end = completion.find("")
- answer_start = completion.find("")
- answer_end = completion.find("")
-
- if (reason_start != -1 and reason_end != -1 and
- answer_start != -1 and answer_end != -1 and
- reason_start < reason_end < answer_start < answer_end):
- reason_content = completion[reason_start+13:reason_end].strip()
- answer_content = completion[answer_start+8:answer_end].strip()
- if reason_content and answer_content:
- scores.append(0.5)
- continue
- scores.append(0.0)
- return scores
-
-def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
- if not completions:
- return [0.0] * len(prompts)
- pattern = r"\n.*?\n\n*?"
- matches = [bool(re.search(pattern, r)) if r else False for r in completions]
- return [0.5 if match else 0.0 for match in matches]
-
-
-def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
- if not completions:
- return [0.0] * len(prompts)
- scores = []
- for text in completions:
- if not text:
- scores.append(0.0)
- continue
- count = 0.0
- if text.count("\n") == 1:
- count += 0.125
- if text.count("") == 1:
- count += 0.125
- if text.count("") == 1:
- count += 0.125
- if text.count("") == 1:
- count += 0.125
- end_text = text.split("")[-1]
- count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
- scores.append(max(0.0, count))
- return scores
-
-
-def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False, end_token: str = "", temperature: float = 0.8):
+def generate_grpo(
+ model: nn.Module,
+ prompts,
+ max_tokens,
+ tokenizer,
+ group_size,
+ is_training=False,
+ end_token: str = "",
+ temperature: float = 0.8
+ ):
if len(prompts.shape) == 1:
prompts = prompts[None, :]
if prompts.shape[1] == 0:
@@ -144,19 +72,38 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
try:
for idx in range(batch_size):
current_tokens = []
-
if is_training:
current_input = expanded_prompts[idx]
prompt_cache = cache.make_prompt_cache(model)
logits = model(current_input[None], cache=prompt_cache)[:, -1]
mx.eval(logits, prompt_cache)
-
while len(current_tokens) < max_tokens:
logits_temp = logits / temperature
probs = nn.softmax(logits_temp, axis=-1)
- next_token = mx.argmax(probs, axis=-1)
+ next_token = mx.random.categorical(logits_temp)
token = next_token.item()
-
+ test_sequence = current_tokens + [token]
+ if (len(test_sequence) >= len(end_sequence) and
+ mx.array_equal(
+ mx.array(test_sequence[-len(end_sequence):]),
+ end_sequence
+ )):
+ current_tokens.append(token)
+ break
+ if token == tokenizer.eos_token_id:
+ break
+ current_tokens.append(token)
+ current_input = mx.array([token])
+ logits = model(current_input[None], cache=prompt_cache)[:, -1]
+ mx.eval(current_input, logits, probs, next_token, token)
+ else:
+ generator = generate_step(
+ expanded_prompts[idx],
+ model,
+ max_tokens=max_tokens,
+ sampler=lambda x: mx.random.categorical(x / temperature)
+ )
+ for token, _ in generator:
test_sequence = current_tokens + [token]
if (len(test_sequence) >= len(end_sequence) and
mx.array_equal(
@@ -168,27 +115,10 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
if token == tokenizer.eos_token_id:
break
-
current_tokens.append(token)
- current_input = mx.array([token])
- logits = model(current_input[None], cache=prompt_cache)[:, -1]
- mx.eval(current_input, logits, probs, next_token)
- else:
- generator = generate_step(
- expanded_prompts[idx],
- model,
- max_tokens=max_tokens,
- sampler=lambda x: mx.argmax(x, axis=-1)
- )
- for token, _ in generator:
- if token == tokenizer.eos_token_id:
- break
- current_tokens.append(token)
-
if current_tokens:
results.append(mx.array(current_tokens))
mx.metal.clear_cache()
-
mx.eval(results)
return results
@@ -207,12 +137,10 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):
seq_logits = logits[i, :seq_len]
seq_targets = targets[i, :seq_len]
log_probs = nn.log_softmax(seq_logits, axis=-1)
-
token_log_probs = mx.take_along_axis(
log_probs,
seq_targets.reshape(seq_len, 1), axis=-1
).squeeze(-1)
-
per_token_logps.append(token_log_probs)
mx.eval(logits)
return per_token_logps
@@ -399,7 +327,7 @@ def grpo_loss(
}
if is_validation:
- print(f"\nValidation sample generation:\n{all_completion_texts[-1]}\n")
+ print(f"\nValidation sample generation:\n{all_completion_texts}\n")
print(f"Validation sample answer:\n{answer_text[-1]}\n")
mx.metal.clear_cache()