first successful training run

Goekdeniz-Guelmez 2025-02-04 09:18:45 +01:00
parent ca32424043
commit 7173840283
3 changed files with 68 additions and 66 deletions

View File

@ -63,11 +63,16 @@ CONFIG_DEFAULTS = {
"config": None,
"grad_checkpoint": False,
"lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
# GRPO args
"reference_model_path": None,
"group_size": 4,
"beta": 0.1,
"epsilon": 1e-4,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
"max_completion_length": 512,
"use_chat_template": False,
"use_prompt": False,
}
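For orientation, the new defaults correspond to the usual GRPO ingredients: group_size completions are sampled per prompt, their rewards are standardized within the group (with epsilon guarding the division), and beta weights the KL penalty against the reference model. A minimal illustrative sketch of the group-relative advantage, not the trainer's exact code:

import mlx.core as mx

# group_size = 4 rewards for one prompt's completions (illustrative values)
rewards = mx.array([0.0, 1.0, 0.5, 1.0])
# standardize within the group; a small epsilon keeps the division stable
advantages = (rewards - mx.mean(rewards)) / (mx.std(rewards) + 1e-4)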
@ -178,9 +183,15 @@ def build_parser():
parser.add_argument(
"--group-size",
type=int,
help="Number of responses per prompt.",
help="Number of generations.",
default=4,
)
parser.add_argument(
"--max-completion-length",
type=int,
help="Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.",
default=512,
)
parser.add_argument(
"--beta",
type=float,
@ -193,6 +204,18 @@ def build_parser():
help="The Epsilon for numerical stability.",
default=1e-4,
)
parser.add_argument(
"--use-chat-template",
action="store_true",
help="If the model is a chat model, apply its chat template to the prompts.",
default=False,
)
parser.add_argument(
"--use-prompt",
action="store_true",
help="Whether to wrap prompts in the system prompt from the R1 paper.",
default=False,
)
return parser
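A hypothetical smoke test for the new flags, assuming build_parser() is importable from this module and the remaining options fall back to CONFIG_DEFAULTS:

# Hypothetical smoke test for the new GRPO CLI flags (flag names taken from the diff above).
parser = build_parser()
args = parser.parse_args([
    "--group-size", "4",
    "--max-completion-length", "512",
    "--use-chat-template",
])
print(args.group_size, args.max_completion_length, args.use_chat_template)  # 4 512 True

With action="store_true", passing the bare flag enables the option and omitting it leaves the default False.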
@ -262,6 +285,7 @@ def train_model(
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
max_completion_length=args.max_completion_length,
grad_checkpoint=args.grad_checkpoint,
beta=args.beta,
group_size=args.group_size,
@ -273,7 +297,7 @@ def train_model(
reference_model, _ = load(args.reference_model_path)
reference_model = reference_model.freeze()
else:
reference_model, _ = None, None
reference_model, _ = load(args.model)
train_grpo(
model=model,

View File

@ -16,14 +16,33 @@ class GRPODataset:
data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
prompt_key: str = "prompt",
answer_key: str = "answer"
answer_key: str = "answer",
use_chat_template: bool = False,
use_prompt: bool = False
):
self._data = []
for item in data:
prompt_str = str(item[prompt_key])
answer_str = str(item[answer_key])
prompt_tokens = tokenizer.encode(prompt_str)
answer_tokens = tokenizer.encode(answer_str)
if use_chat_template:
prompt_tokens = tokenizer.apply_chat_template(
[
{'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."""},
{'role': 'user', 'content': prompt_str}
],
)
answer_tokens = tokenizer.encode(answer_str)
else:
if use_prompt:
prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.
User: {prompt_str}. Assistant: """)
else:
prompt_tokens = tokenizer.encode(prompt_str)
answer_tokens = tokenizer.encode(answer_str)
self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:

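A small usage sketch of the extended GRPODataset; the model id below is a placeholder, and the tokenizer comes from mlx_lm.load:

from mlx_lm import load

# Placeholder model id; any chat-tuned model with a chat template works here.
_, tokenizer = load("Qwen/Qwen2.5-0.5B-Instruct")
data = [{"prompt": "What is 2 + 2?",
         "answer": "<think>2 + 2 = 4</think><answer>4</answer>"}]
dataset = GRPODataset(data, tokenizer, use_chat_template=True)
prompt_tokens, answer_tokens, prompt_str, answer_str = dataset[0]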
View File

@ -12,8 +12,6 @@ from mlx.utils import tree_flatten
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
from mlx_lm.utils import generate_step
@dataclass
class GRPOTrainingArgs(TrainingArgs):
@ -27,6 +25,9 @@ class GRPOTrainingArgs(TrainingArgs):
epsilon: float = field(
default=1e-4, metadata={"help": "The Epsilon for numerical stability."}
)
max_completion_length: int = field(
default=512, metadata={"help": "Maximum number of tokens to generate for each completion."}
)
reference_model_path: str = field(
default=None,
metadata={
@ -36,7 +37,6 @@ class GRPOTrainingArgs(TrainingArgs):
def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
model.eval()
if len(prompt.shape) == 1:
prompt = prompt[None, :]
@ -58,11 +58,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
token_value = next_token.item()
generated.append(next_token)
# Clear intermediate tensors
del logits, token_logits, probs
mx.metal.clear_cache()
current_prompt = mx.concatenate([current_prompt, next_token[None]])
if token_value == tokenizer.eos_token_id:
break
@ -72,12 +68,6 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
result = mx.concatenate([prompt[0], mx.stack(generated)])
mx.eval(result)
model.train()
# Clear generated tokens
del generated
mx.metal.clear_cache()
return result
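For context, a minimal sketch of what the sampling step inside generate_grpo likely looks like, inferred from the variable names in the hunk above (the full loop also handles EOS and cache management, and the exact sampling code in the file may differ):

# current_prompt is 2D ([1, T]) because the function adds a batch dimension up front
logits = model(current_prompt)[:, -1, :]          # next-token logits at the last position
token_logits = logits / temperature               # temperature scaling
probs = mx.softmax(token_logits, axis=-1)         # probabilities over the vocabulary
next_token = mx.random.categorical(token_logits)  # sample one token id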
@ -192,11 +182,6 @@ def get_per_token_logps(model, inputs, lengths):
).squeeze(-1) # [seq_len]
per_token_logps.append(token_log_probs)
# Clean up intermediates
del seq_logits, seq_targets, log_probs, token_log_probs
mx.metal.clear_cache()
return per_token_logps
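For reference, a vectorized sketch of the per-token log-probability gather that get_per_token_logps performs (the file's version iterates per sequence and later masks with lengths):

# Hedged sketch; shapes assumed as logits [batch, seq_len, vocab], inputs [batch, seq_len].
logits = model(inputs).astype(mx.float32)
log_probs = nn.log_softmax(logits[:, :-1, :], axis=-1)   # token t+1 predicted from prefix t
targets = inputs[:, 1:]
token_log_probs = mx.take_along_axis(
    log_probs, targets[..., None], axis=-1
).squeeze(-1)                                            # [batch, seq_len - 1]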
@ -232,15 +217,9 @@ def grpo_loss(
all_completions.append(completion_ids)
all_completion_texts.append(completion_text)
del completion_ids
mx.metal.clear_cache()
except Exception as e:
print(f"Generation error: {e}")
continue
del prompt_tensor
mx.metal.clear_cache()
# Prepare inputs
expanded_answers = []
@ -264,25 +243,11 @@ def grpo_loss(
mask = mx.ones_like(completion_ids)
padded_completions.append(padded_ids)
attention_masks.append(mask)
del completion_ids
if padding_length > 0:
del padding
del mask
mx.metal.clear_cache()
inputs = mx.stack(padded_completions)
attention_mask = mx.stack(attention_masks)
lengths = attention_mask.sum(axis=1)
del padded_completions, attention_masks
mx.metal.clear_cache()
# Get logits and compute log probabilities
logits = model(inputs).astype(mx.float32)
log_probs = nn.log_softmax(logits[:, :-1, :], axis=-1)
targets = inputs[:, 1:]
# Current policy probabilities
token_log_probs = get_per_token_logps(model, inputs, lengths)
@ -302,9 +267,6 @@ def grpo_loss(
padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))
del padding
mx.metal.clear_cache()
token_log_probs = mx.stack(padded_log_probs)
ref_token_log_probs = mx.stack(padded_ref_log_probs)
@ -360,10 +322,6 @@ def grpo_loss(
reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards)
reward_metrics[f'{func_name}_std'] = mx.std(func_rewards)
# Clean up
del all_completions
mx.metal.clear_cache()
metrics = {
'total_rewards_mean': mx.mean(rewards),
'total_rewards_std': mx.std(rewards),
@ -440,7 +398,7 @@ def evaluate_grpo(
group_size: int,
max_seq_length,
reward_funcs = None,
loss: callable = grpo_loss,
loss_fn: callable = grpo_loss,
iterate_batches: callable = iterate_grpo_batches
):
"""
@ -466,7 +424,7 @@ def evaluate_grpo(
),
):
# Calculate loss for current batch
losses, toks, metrics = loss(
losses, toks, metrics = loss_fn(
model=model,
tokenizer=tokenizer,
batch=batch,
@ -518,7 +476,7 @@ def train_grpo(
r1_count_xml
],
args: GRPOTrainingArgs = GRPOTrainingArgs(),
loss: callable = grpo_loss,
loss_fn: callable = grpo_loss,
iterate_batches: callable = iterate_grpo_batches,
training_callback: TrainingCallback = None,
):
@ -546,7 +504,7 @@ def train_grpo(
group_size=args.group_size,
epsilon=args.epsilon,
ref_model=ref_model,
max_tokens=args.max_seq_length,
max_tokens=args.max_completion_length,
)
# All reduce the gradients if running in distributed mode
@ -557,22 +515,23 @@ def train_grpo(
return loss, toks, metrics
loss_value_and_grad = nn.value_and_grad(model, loss)
loss_value_and_grad = nn.value_and_grad(model, loss_fn)
losses = 0
n_tokens = 0
steps = 0
trained_tokens = 0
accumulated_metrics = {
'rewards': 0,
'rewards_std': 0,
'grouped_rewards': 0,
'total_rewards_mean': 0,
'total_rewards_std': 0,
'grouped_rewards_mean': 0,
'grouped_rewards_std': 0,
'kl': 0
}
for i in range(len(reward_funcs)):
accumulated_metrics[f'reward_func_{i}_mean'] = 0
accumulated_metrics[f'reward_func_{i}_std'] = 0
for reward_func in reward_funcs:
func_name = reward_func.__name__
accumulated_metrics[f'{func_name}_mean'] = 0
accumulated_metrics[f'{func_name}_std'] = 0
start = time.perf_counter()
for it, batch in zip(
@ -592,7 +551,7 @@ def train_grpo(
val_loss, val_ntokens, val_metrics = evaluate_grpo(
model=model,
dataset=val_dataset,
loss=loss,
loss_fn=loss_fn,
ref_model=ref_model,
reward_funcs=reward_funcs,
tokenizer=tokenizer,
@ -675,8 +634,8 @@ def train_grpo(
for i, reward_func in enumerate(reward_funcs):
func_name = reward_func.__name__
train_metrics_str += (
f", Reward func {reward_func.__name__} mean {avg_metrics[f'reward_func_{reward_func.__name__}_mean']:.3f}, "
f"Reward func {reward_func.__name__} std {avg_metrics[f'reward_func_{reward_func.__name__}_std']:.3f}"
f", {func_name} mean {avg_metrics[f'{func_name}_mean']:.3f}, "
f"{func_name} std {avg_metrics[f'{func_name}_std']:.3f}"
)
print(