diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 1f684d27..1e4fe3d5 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -63,11 +63,16 @@ CONFIG_DEFAULTS = {
"config": None,
"grad_checkpoint": False,
"lr_schedule": None,
+ "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
+
+ # GRPO args
"reference_model_path": None,
"group_size": 4,
"beta": 0.1,
"epsilon": 1e-4,
- "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
+ "max_completion_length": 512,
+ "use_chat_template": False,
+ "use_prompt": False,
}
@@ -178,9 +183,15 @@ def build_parser():
parser.add_argument(
"--group-size",
type=int,
- help="Number of responses per prompt.",
+ help="Number of generations.",
default=4,
)
+ parser.add_argument(
+ "--max-completion-length",
+ type=int,
+ help="Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.",
+ default=512,
+ )
parser.add_argument(
"--beta",
type=float,
@@ -193,6 +204,18 @@ def build_parser():
help="The Epsilon for numerical stability.",
default=1e-4,
)
+ parser.add_argument(
+ "--use-chat-template",
+        action="store_true",
+        help="If the model is a chat model, apply its chat template to the prompts.",
+        default=False,
+ )
+ parser.add_argument(
+ "--use-prompt",
+        action="store_true",
+        help="Whether to use the system prompt from the R1 paper.",
+        default=False,
+ )
return parser
@@ -262,6 +285,7 @@ def train_model(
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
+ max_completion_length=args.max_completion_length,
grad_checkpoint=args.grad_checkpoint,
beta=args.beta,
group_size=args.group_size,
@@ -273,7 +297,7 @@ def train_model(
reference_model, _ = load(args.reference_model_path)
reference_model = reference_model.freeze()
else:
- reference_model, _ = None, None
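+        # No reference_model_path given: load a fresh copy of the base model to act as the reference policy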
+ reference_model, _ = load(args.model)
train_grpo(
model=model,
diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 163b1be7..b31656c6 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -16,14 +16,33 @@ class GRPODataset:
data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
prompt_key: str = "prompt",
- answer_key: str = "answer"
+ answer_key: str = "answer",
+ use_chat_template: bool = False,
+ use_prompt: bool = False
):
self._data = []
for item in data:
prompt_str = str(item[prompt_key])
answer_str = str(item[answer_key])
- prompt_tokens = tokenizer.encode(prompt_str)
- answer_tokens = tokenizer.encode(answer_str)
+ if use_chat_template:
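+                # apply_chat_template tokenizes by default, so this already returns prompt token ids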
+ prompt_tokens = tokenizer.apply_chat_template(
+ [
+                        {'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
+                        The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
+                        The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>."""},
+ {'role': 'user', 'content': prompt_str}
+ ],
+ )
+ answer_tokens = tokenizer.encode(answer_str)
+ else:
+ if use_prompt:
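+                    # Wrap the raw prompt in the R1-paper style plain-text preamble instead of a chat template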
+ prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
+ The assistantfirst thinks about the reasoning process in the mind and then provides the user with the answer.
+ The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here .
+ User: {prompt_str}. Assistant: """)
+ else:
+ prompt_tokens = tokenizer.encode(prompt_str)
+ answer_tokens = tokenizer.encode(answer_str)
self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 29518d8f..3b5dc2e9 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -12,8 +12,6 @@ from mlx.utils import tree_flatten
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
-from mlx_lm.utils import generate_step
-
@dataclass
class GRPOTrainingArgs(TrainingArgs):
@@ -27,6 +25,9 @@ class GRPOTrainingArgs(TrainingArgs):
epsilon: float = field(
default=1e-4, metadata={"help": "The Epsilon for numerical stability."}
)
+ max_completion_length: int = field(
+        default=512, metadata={"help": "Maximum number of tokens to generate for each completion."}
+ )
reference_model_path: str = field(
default=None,
metadata={
@@ -36,7 +37,6 @@ class GRPOTrainingArgs(TrainingArgs):
def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
- model.eval()
if len(prompt.shape) == 1:
prompt = prompt[None, :]
@@ -58,11 +58,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
token_value = next_token.item()
generated.append(next_token)
-
- # Clear intermediate tensors
- del logits, token_logits, probs
- mx.metal.clear_cache()
-
+
current_prompt = mx.concatenate([current_prompt, next_token[None]])
if token_value == tokenizer.eos_token_id:
break
@@ -72,12 +68,6 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
result = mx.concatenate([prompt[0], mx.stack(generated)])
mx.eval(result)
- model.train()
-
- # Clear generated tokens
- del generated
- mx.metal.clear_cache()
-
return result
@@ -192,11 +182,6 @@ def get_per_token_logps(model, inputs, lengths):
).squeeze(-1) # [seq_len]
per_token_logps.append(token_log_probs)
-
- # Clean up intermediates
- del seq_logits, seq_targets, log_probs, token_log_probs
- mx.metal.clear_cache()
-
return per_token_logps
@@ -232,15 +217,9 @@ def grpo_loss(
all_completions.append(completion_ids)
all_completion_texts.append(completion_text)
- del completion_ids
- mx.metal.clear_cache()
-
except Exception as e:
print(f"Generation error: {e}")
continue
-
- del prompt_tensor
- mx.metal.clear_cache()
# Prepare inputs
expanded_answers = []
@@ -264,25 +243,11 @@ def grpo_loss(
mask = mx.ones_like(completion_ids)
padded_completions.append(padded_ids)
attention_masks.append(mask)
-
- del completion_ids
- if padding_length > 0:
- del padding
- del mask
- mx.metal.clear_cache()
inputs = mx.stack(padded_completions)
attention_mask = mx.stack(attention_masks)
lengths = attention_mask.sum(axis=1)
- del padded_completions, attention_masks
- mx.metal.clear_cache()
-
- # Get logits and compute log probabilities
- logits = model(inputs).astype(mx.float32)
- log_probs = nn.log_softmax(logits[:, :-1, :], axis=-1)
- targets = inputs[:, 1:]
-
# Current policy probabilities
token_log_probs = get_per_token_logps(model, inputs, lengths)
@@ -302,9 +267,6 @@ def grpo_loss(
padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))
-
- del padding
- mx.metal.clear_cache()
token_log_probs = mx.stack(padded_log_probs)
ref_token_log_probs = mx.stack(padded_ref_log_probs)
@@ -360,10 +322,6 @@ def grpo_loss(
reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards)
reward_metrics[f'{func_name}_std'] = mx.std(func_rewards)
- # Clean up
- del all_completions
- mx.metal.clear_cache()
-
metrics = {
'total_rewards_mean': mx.mean(rewards),
'total_rewards_std': mx.std(rewards),
@@ -440,7 +398,7 @@ def evaluate_grpo(
group_size: int,
max_seq_length,
reward_funcs = None,
- loss: callable = grpo_loss,
+ loss_fn: callable = grpo_loss,
iterate_batches: callable = iterate_grpo_batches
):
"""
@@ -466,7 +424,7 @@ def evaluate_grpo(
),
):
# Calculate loss for current batch
- losses, toks, metrics = loss(
+ losses, toks, metrics = loss_fn(
model=model,
tokenizer=tokenizer,
batch=batch,
@@ -518,7 +476,7 @@ def train_grpo(
r1_count_xml
],
args: GRPOTrainingArgs = GRPOTrainingArgs(),
- loss: callable = grpo_loss,
+ loss_fn: callable = grpo_loss,
iterate_batches: callable = iterate_grpo_batches,
training_callback: TrainingCallback = None,
):
@@ -546,7 +504,7 @@ def train_grpo(
group_size=args.group_size,
epsilon=args.epsilon,
ref_model=ref_model,
- max_tokens=args.max_seq_length,
+ max_tokens=args.max_completion_length,
)
# All reduce the gradients if running in distributed mode
@@ -557,22 +515,23 @@ def train_grpo(
return loss, toks, metrics
- loss_value_and_grad = nn.value_and_grad(model, loss)
+ loss_value_and_grad = nn.value_and_grad(model, loss_fn)
losses = 0
n_tokens = 0
steps = 0
trained_tokens = 0
accumulated_metrics = {
- 'rewards': 0,
- 'rewards_std': 0,
- 'grouped_rewards': 0,
+ 'total_rewards_mean': 0,
+ 'total_rewards_std': 0,
+ 'grouped_rewards_mean': 0,
'grouped_rewards_std': 0,
'kl': 0
}
- for i in range(len(reward_funcs)):
- accumulated_metrics[f'reward_func_{i}_mean'] = 0
- accumulated_metrics[f'reward_func_{i}_std'] = 0
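+    # Accumulate per-reward-function metrics keyed by function name, matching the keys produced by grpo_loss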
+ for reward_func in reward_funcs:
+ func_name = reward_func.__name__
+ accumulated_metrics[f'{func_name}_mean'] = 0
+ accumulated_metrics[f'{func_name}_std'] = 0
start = time.perf_counter()
for it, batch in zip(
@@ -592,7 +551,7 @@ def train_grpo(
val_loss, val_ntokens, val_metrics = evaluate_grpo(
model=model,
dataset=val_dataset,
- loss=loss,
+ loss_fn=loss_fn,
ref_model=ref_model,
reward_funcs=reward_funcs,
tokenizer=tokenizer,
@@ -675,8 +634,8 @@ def train_grpo(
for i, reward_func in enumerate(reward_funcs):
func_name = reward_func.__name__
train_metrics_str += (
- f", Reward func {reward_func.__name__} mean {avg_metrics[f'reward_func_{reward_func.__name__}_mean']:.3f}, "
- f"Reward func {reward_func.__name__} std {avg_metrics[f'reward_func_{reward_func.__name__}_std']:.3f}"
+ f", {func_name} mean {avg_metrics[f'{func_name}_mean']:.3f}, "
+ f"{func_name} std {avg_metrics[f'{func_name}_std']:.3f}"
)
print(