mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00

first successful training run
@@ -12,8 +12,6 @@ from mlx.utils import tree_flatten
 from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
 from mlx_lm.utils import generate_step


 @dataclass
 class GRPOTrainingArgs(TrainingArgs):
@@ -27,6 +25,9 @@ class GRPOTrainingArgs(TrainingArgs):
     epsilon: float = field(
         default=1e-4, metadata={"help": "The Epsilon for numerical stability."}
     )
+    max_completion_length: int = field(
+        default=512, metadata={"help": "Number of Generations."}
+    )
     reference_model_path: str = field(
         default=None,
         metadata={
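For context, a minimal construction sketch for the arguments touched in this hunk (only the fields visible above are set; anything else falls back to the TrainingArgs defaults, and the field descriptions in the comments are assumptions drawn from how the fields are used later in this diff):

import_path_note = "assumes GRPOTrainingArgs is importable from this trainer module"
args = GRPOTrainingArgs(
    epsilon=1e-4,                # numerical-stability constant used in the loss
    max_completion_length=512,   # per-completion generation budget, wired to max_tokens later in this commit
    reference_model_path=None,   # optional path to a separate reference-model checkpoint
)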
@@ -36,7 +37,6 @@ class GRPOTrainingArgs(TrainingArgs):


 def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
     model.eval()
     if len(prompt.shape) == 1:
         prompt = prompt[None, :]

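A hedged usage sketch for the signature above; `model` and `tokenizer` stand in for an MLX causal LM and its tokenizer loaded elsewhere (assumed), and the prompt can be passed 1-D because the function adds the batch dimension itself:

import mlx.core as mx

prompt_ids = mx.array(tokenizer.encode("What is 2 + 2?"))      # 1-D token ids
completion = generate_grpo(model, prompt_ids, max_tokens=512, tokenizer=tokenizer)
print(tokenizer.decode(completion.tolist()))                   # prompt + generated tokens, up to EOS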
@@ -58,11 +58,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
             token_value = next_token.item()
             generated.append(next_token)
-
-            # Clear intermediate tensors
-            del logits, token_logits, probs
-            mx.metal.clear_cache()

             current_prompt = mx.concatenate([current_prompt, next_token[None]])
             if token_value == tokenizer.eos_token_id:
                 break
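The `del logits, token_logits, probs` line in this hunk names the intermediates of the sampling step that produces next_token; a reconstruction of that step under assumptions (not the file's exact code):

import mlx.core as mx

logits = model(current_prompt[None])                 # assumes current_prompt is a 1-D token sequence
token_logits = logits[:, -1, :] / temperature        # last-position logits, temperature-scaled
probs = mx.softmax(token_logits, axis=-1)            # probabilities, kept only for inspection
next_token = mx.random.categorical(token_logits)[0]  # sample one token id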
@@ -72,12 +68,6 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):

         result = mx.concatenate([prompt[0], mx.stack(generated)])
         mx.eval(result)
         model.train()
-
-        # Clear generated tokens
-        del generated
-        mx.metal.clear_cache()

         return result

@@ -192,11 +182,6 @@ def get_per_token_logps(model, inputs, lengths):
         ).squeeze(-1)  # [seq_len]

         per_token_logps.append(token_log_probs)
-
-        # Clean up intermediates
-        del seq_logits, seq_targets, log_probs, token_log_probs
-        mx.metal.clear_cache()

     return per_token_logps

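The `).squeeze(-1)` line above is the tail of a per-token log-probability gather; a single-sequence sketch of that computation, with shapes assumed from the surrounding code:

import mlx.core as mx
import mlx.nn as nn

seq_logits = model(inputs[i : i + 1])[0, :-1, :].astype(mx.float32)  # positions 0..T-2 predict tokens 1..T-1
log_probs = nn.log_softmax(seq_logits, axis=-1)
seq_targets = inputs[i, 1:]
token_log_probs = mx.take_along_axis(
    log_probs, seq_targets[:, None], axis=-1
).squeeze(-1)  # one log-prob per predicted token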
@@ -232,15 +217,9 @@ def grpo_loss(
             all_completions.append(completion_ids)
             all_completion_texts.append(completion_text)
-
-            del completion_ids
-            mx.metal.clear_cache()

         except Exception as e:
             print(f"Generation error: {e}")
             continue
-
-        del prompt_tensor
-        mx.metal.clear_cache()

     # Prepare inputs
     expanded_answers = []
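For orientation, the append/except lines above sit inside a per-prompt generation loop roughly like the following; this is a sketch whose variable names come from this hunk, with the loop bound assumed to be group_size:

for _ in range(group_size):
    try:
        completion_ids = generate_grpo(model, prompt_tensor, max_tokens, tokenizer)
        completion_text = tokenizer.decode(completion_ids.tolist())
        all_completions.append(completion_ids)
        all_completion_texts.append(completion_text)
    except Exception as e:
        print(f"Generation error: {e}")
        continue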
@@ -264,25 +243,11 @@ def grpo_loss(
         mask = mx.ones_like(completion_ids)
         padded_completions.append(padded_ids)
         attention_masks.append(mask)
-
-        del completion_ids
-        if padding_length > 0:
-            del padding
-        del mask
-        mx.metal.clear_cache()

     inputs = mx.stack(padded_completions)
     attention_mask = mx.stack(attention_masks)
     lengths = attention_mask.sum(axis=1)
-
-    del padded_completions, attention_masks
-    mx.metal.clear_cache()
-
-    # Get logits and compute log probabilities
-    logits = model(inputs).astype(mx.float32)
-    log_probs = nn.log_softmax(logits[:, :-1, :], axis=-1)
-    targets = inputs[:, 1:]

     # Current policy probabilities
     token_log_probs = get_per_token_logps(model, inputs, lengths)

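The padding loop at the top of this hunk right-pads each completion to a common length and builds a mask; one plausible spelling of that loop, where the pad value and the zero-extension of the mask are assumptions not shown in this diff:

import mlx.core as mx

max_length = max(c.shape[0] for c in all_completions)
padded_completions, attention_masks = [], []
for completion_ids in all_completions:
    padding_length = max_length - completion_ids.shape[0]
    mask = mx.ones_like(completion_ids)
    if padding_length > 0:
        padding = mx.full((padding_length,), tokenizer.eos_token_id, dtype=completion_ids.dtype)
        padded_ids = mx.concatenate([completion_ids, padding])
        mask = mx.concatenate([mask, mx.zeros((padding_length,), dtype=mask.dtype)])
    else:
        padded_ids = completion_ids
    padded_completions.append(padded_ids)
    attention_masks.append(mask)

inputs = mx.stack(padded_completions)            # [num_completions, max_length]
attention_mask = mx.stack(attention_masks)
lengths = attention_mask.sum(axis=1)             # real-token count per row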
@@ -302,9 +267,6 @@ def grpo_loss(

         padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
         padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))
-
-        del padding
-        mx.metal.clear_cache()

     token_log_probs = mx.stack(padded_log_probs)
     ref_token_log_probs = mx.stack(padded_ref_log_probs)
@@ -360,10 +322,6 @@ def grpo_loss(
         reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards)
         reward_metrics[f'{func_name}_std'] = mx.std(func_rewards)
-
-    # Clean up
-    del all_completions
-    mx.metal.clear_cache()

     metrics = {
         'total_rewards_mean': mx.mean(rewards),
         'total_rewards_std': mx.std(rewards),
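Not part of this hunk, but for context on the grouped_rewards_* statistics tracked here and in the training loop below: GRPO-style losses typically standardize rewards within each group of completions, along these lines (a generic sketch; group_size and epsilon are the GRPOTrainingArgs fields shown earlier, not code from this file):

import mlx.core as mx

grouped_rewards = rewards.reshape(-1, group_size)
advantages = (grouped_rewards - mx.mean(grouped_rewards, axis=1, keepdims=True)) / (
    mx.std(grouped_rewards, axis=1, keepdims=True) + epsilon
)
advantages = advantages.reshape(-1)  # one advantage per completion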
@@ -440,7 +398,7 @@ def evaluate_grpo(
     group_size: int,
     max_seq_length,
     reward_funcs = None,
-    loss: callable = grpo_loss,
+    loss_fn: callable = grpo_loss,
     iterate_batches: callable = iterate_grpo_batches
 ):
     """
@@ -466,7 +424,7 @@ def evaluate_grpo(
         ),
     ):
         # Calculate loss for current batch
-        losses, toks, metrics = loss(
+        losses, toks, metrics = loss_fn(
            model=model,
            tokenizer=tokenizer,
            batch=batch,
@@ -518,7 +476,7 @@ def train_grpo(
         r1_count_xml
     ],
     args: GRPOTrainingArgs = GRPOTrainingArgs(),
-    loss: callable = grpo_loss,
+    loss_fn: callable = grpo_loss,
     iterate_batches: callable = iterate_grpo_batches,
     training_callback: TrainingCallback = None,
 ):
@@ -546,7 +504,7 @@ def train_grpo(
             group_size=args.group_size,
             epsilon=args.epsilon,
             ref_model=ref_model,
-            max_tokens=args.max_seq_length,
+            max_tokens=args.max_completion_length,
         )

         # All reduce the gradients if running in distributed mode
@@ -557,22 +515,23 @@ def train_grpo(

         return loss, toks, metrics

-    loss_value_and_grad = nn.value_and_grad(model, loss)
+    loss_value_and_grad = nn.value_and_grad(model, loss_fn)

     losses = 0
     n_tokens = 0
     steps = 0
     trained_tokens = 0
     accumulated_metrics = {
-        'rewards': 0,
-        'rewards_std': 0,
-        'grouped_rewards': 0,
+        'total_rewards_mean': 0,
+        'total_rewards_std': 0,
         'grouped_rewards_mean': 0,
         'grouped_rewards_std': 0,
         'kl': 0
     }
-    for i in range(len(reward_funcs)):
-        accumulated_metrics[f'reward_func_{i}_mean'] = 0
-        accumulated_metrics[f'reward_func_{i}_std'] = 0
+    for reward_func in reward_funcs:
+        func_name = reward_func.__name__
+        accumulated_metrics[f'{func_name}_mean'] = 0
+        accumulated_metrics[f'{func_name}_std'] = 0

     start = time.perf_counter()
     for it, batch in zip(
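The switch from index-based keys (reward_func_{i}_mean) to name-based keys ({func_name}_mean) matters because grpo_loss reports its per-function metrics under each function's __name__ (see the @@ -360 hunk above), so the accumulator must use the same keys. A small illustration with the one reward function named in this excerpt, using a stand-in definition:

def r1_count_xml(*args, **kwargs):   # stand-in; the real reward function is defined elsewhere
    return 0.0

accumulated_metrics = {}
for reward_func in [r1_count_xml]:
    func_name = reward_func.__name__
    accumulated_metrics[f'{func_name}_mean'] = 0     # 'r1_count_xml_mean'
    accumulated_metrics[f'{func_name}_std'] = 0      # 'r1_count_xml_std'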
@@ -592,7 +551,7 @@ def train_grpo(
             val_loss, val_ntokens, val_metrics = evaluate_grpo(
                 model=model,
                 dataset=val_dataset,
-                loss=loss,
+                loss_fn=loss_fn,
                 ref_model=ref_model,
                 reward_funcs=reward_funcs,
                 tokenizer=tokenizer,
@@ -675,8 +634,8 @@ def train_grpo(
         for i, reward_func in enumerate(reward_funcs):
             func_name = reward_func.__name__
             train_metrics_str += (
-                f", Reward func {reward_func.__name__} mean {avg_metrics[f'reward_func_{reward_func.__name__}_mean']:.3f}, "
-                f"Reward func {reward_func.__name__} std {avg_metrics[f'reward_func_{reward_func.__name__}_std']:.3f}"
+                f", {func_name} mean {avg_metrics[f'{func_name}_mean']:.3f}, "
+                f"{func_name} std {avg_metrics[f'{func_name}_std']:.3f}"
            )

         print(
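With the new format strings, the per-function fragment appended to train_metrics_str looks like this (dummy metric values; r1_count_xml is the reward function named earlier in the diff):

avg_metrics = {'r1_count_xml_mean': 0.412, 'r1_count_xml_std': 0.095}
func_name = 'r1_count_xml'
fragment = (
    f", {func_name} mean {avg_metrics[f'{func_name}_mean']:.3f}, "
    f"{func_name} std {avg_metrics[f'{func_name}_std']:.3f}"
)
# fragment == ", r1_count_xml mean 0.412, r1_count_xml std 0.095"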