Slightly faster generation + prints out an example generation in validation mode, more optimization in the training function

Goekdeniz-Guelmez 2025-02-21 16:02:27 +01:00
parent 11c8991476
commit 2f20107d9b
2 changed files with 119 additions and 59 deletions


@@ -374,7 +374,9 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
         max_seq_length=args.max_seq_length,
         beta=args.beta,
         group_size=args.group_size,
-        epsilon=args.epsilon
+        epsilon=args.epsilon,
+        temperature=args.temperature,
+        max_tokens=args.max_seq_length
     )
     test_ppl = math.exp(test_loss)


@@ -74,15 +74,14 @@ def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
     return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]

 def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
-    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
-    matches = [bool(re.search(pattern, r)) if r else False for r in completions]
+    has_think = r"<think>.*</think>"
+    has_answer = r"<answer>.*</answer>"
+    matches = [(bool(re.search(has_think, r)) and bool(re.search(has_answer, r))) if r else False for r in completions]
     return [0.5 if match else 0.0 for match in matches]

 def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
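As a standalone illustration (not part of the commit), the sketch below shows how the relaxed soft-format check behaves: the new version only requires that a <think> block and an <answer> block each appear somewhere in the completion, whereas the old pattern required a <think> block followed by an <answer> block.

import re

has_think = r"<think>.*</think>"
has_answer = r"<answer>.*</answer>"

# Tags present but in reversed order: the old ordered pattern scores 0.0,
# the new independent checks score 0.5.
completion = "<answer>4</answer> <think>2 + 2 = 4</think>"
match = bool(re.search(has_think, completion)) and bool(re.search(has_answer, completion))
print(0.5 if match else 0.0)  # 0.5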
@@ -114,44 +113,95 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     return scores

-def generate_grpo(model: nn.Module, prompt, max_tokens, tokenizer, temperature):
-    if len(prompt.shape) == 1:
-        prompt = prompt[None, :]
-    if prompt.shape[1] == 0:
+def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, temperature, group_size):
+    if len(prompts.shape) == 1:
+        prompts = prompts[None, :]
+    if prompts.shape[1] == 0:
         return None
-    end_sequence = tokenizer.encode("</answer>")
-    end_sequence_length = len(end_sequence)
-    initial_length = prompt.shape[1]
-    output = mx.zeros((initial_length + max_tokens,), dtype=mx.int32)
-    output[:initial_length] = prompt[0]
-    current_length = initial_length
+
+    start_time = time.perf_counter()
+    tokens_generated = 0
+    batch_size = prompts.shape[0] * group_size
+
+    # Repeat each prompt group_size times
+    expanded_prompts = mx.repeat(prompts, group_size, axis=0)
+    end_sequence = mx.array(tokenizer.encode("</answer>"))
+    end_len = len(end_sequence)
+    initial_length = prompts.shape[1]
+
+    # Initialize output tensor for all sequences
+    output = mx.zeros((batch_size, initial_length + max_tokens), dtype=mx.int32)
+    output = mx.concatenate([expanded_prompts, mx.zeros((batch_size, max_tokens), dtype=mx.int32)], axis=1)
+    current_lengths = mx.array([initial_length] * batch_size)
+    temp_factor = 1/temperature if temperature > 0 else float('inf')

     try:
-        def sample(logits):
-            if temperature > 0:
-                logits /= temperature
-            logprobs = logits - mx.logsumexp(logits, keepdims=True)
-            return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]
+        not_finished = mx.ones((batch_size,), dtype=mx.bool_)

         for _ in range(max_tokens):
-            current_input = output[:current_length][None, :]
-            logits = model(current_input)
-            token_logits = logits[0, -1]
-            next_token = sample(token_logits)
-            token_value = next_token.item()
-            output[current_length] = token_value
-            current_length += 1
-            if token_value == tokenizer.eos_token_id:
+            # Check if all sequences are finished
+            if not mx.sum(not_finished).item():
                 break
-            if current_length >= end_sequence_length:
-                last_tokens = output[current_length - end_sequence_length:current_length].tolist()
-                if last_tokens == end_sequence:
-                    break
-        if current_length > initial_length:
-            return output[:current_length]
+
+            # Get model outputs for all sequences
+            max_len = mx.max(current_lengths).item()
+            batch_inputs = output[:, :max_len]
+            logits = model(batch_inputs)[:, -1]
+
+            # Apply mask to logits
+            logits = logits * mx.expand_dims(not_finished, -1)
+
+            # Sample next tokens
+            logits *= temp_factor
+            logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
+            next_tokens = mx.random.categorical(logprobs)
+
+            # Update outputs for active sequences
+            for idx in range(batch_size):
+                if not_finished[idx].item():
+                    curr_len = current_lengths[idx].item()
+                    token_value = next_tokens[idx].item()
+
+                    # Create new arrays with updates
+                    output = mx.array(output.tolist())  # Make a copy
+                    output[idx, curr_len] = token_value
+                    current_lengths = mx.array([
+                        l + 1 if i == idx else l
+                        for i, l in enumerate(current_lengths.tolist())
+                    ])
+                    tokens_generated += 1
+
+                    # Check end conditions
+                    if token_value == tokenizer.eos_token_id:
+                        not_finished = mx.array([
+                            False if i == idx else nf
+                            for i, nf in enumerate(not_finished.tolist())
+                        ])
+                    elif curr_len >= end_len:
+                        last_tokens = output[idx, curr_len-end_len+1:curr_len+1]
+                        if mx.array_equal(last_tokens, end_sequence):
+                            not_finished = mx.array([
+                                False if i == idx else nf
+                                for i, nf in enumerate(not_finished.tolist())
+                            ])
+
+            if _ % 32 == 0:
+                mx.eval(output, current_lengths, not_finished)
+
+        end_time = time.perf_counter()
+        generation_time = end_time - start_time
+        tokens_per_second = tokens_generated / generation_time
+        print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_per_second:.2f} tokens/s)")
+
+        # Return only the valid part of each sequence
+        results = [output[i, :current_lengths[i].item()] for i in range(batch_size)]
+        return results
+
     except Exception as e:
         print(f"Generation error: {str(e)}")
         return None
-    return None

 def get_per_token_logps(model: nn.Module, inputs, lengths):
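For orientation, here is a minimal sketch of the batched sampling idea behind the new generate_grpo (a simplified, hypothetical helper, not the repository's code): each prompt is repeated group_size times and a single forward pass samples the next token for every completion at once. The per-sequence bookkeeping from the diff (tracking lengths, EOS, and the </answer> stop sequence) is omitted here.

import mlx.core as mx

def sample_next_tokens(model, prompts, group_size, temperature=1.0):
    # prompts: (num_prompts, prompt_len) int32 token ids
    expanded = mx.repeat(prompts, group_size, axis=0)   # (num_prompts * group_size, prompt_len)
    logits = model(expanded)[:, -1]                     # logits at the last position of each sequence
    if temperature > 0:
        logits = logits / temperature
    logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
    return mx.random.categorical(logprobs)              # one sampled token id per expanded row

Because all group_size samples per prompt share one forward pass, this is the main source of the speedup the commit message refers to.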
@@ -185,7 +235,8 @@ def grpo_loss(
     epsilon=1e-4,
     max_tokens=64,
     temperature=1.0,
-    reward_weights=None
+    reward_weights=None,
+    is_validation=False
 ):
     prompt_tokens, _, prompt_text, answer_text = batch
     batch_size = len(prompt_tokens)
@@ -195,22 +246,27 @@ def grpo_loss(
     for i in range(0, batch_size, batch_size):
         batch_prompts = prompt_tokens[i:i+batch_size]
-        for prompt in batch_prompts:
-            prompt_tensor = mx.array(prompt)
-            for _ in range(group_size):
-                try:
-                    completion_ids = generate_grpo(model, prompt_tensor, max_tokens, tokenizer, temperature)
-                    if completion_ids is not None:
-                        completion_text = tokenizer.decode(completion_ids.tolist())
-                        all_completions.append(completion_ids)
-                        all_completion_texts.append(completion_text)
-                        mx.eval(completion_ids)
-                        del completion_ids
-                except Exception as e:
-                    print(f"Generation error: {e}")
-                    continue
-
-            mx.metal.clear_cache()
+        prompt_tensor = mx.stack([mx.array(p) for p in batch_prompts])
+
+        try:
+            completions = generate_grpo(
+                model,
+                prompt_tensor,
+                max_tokens,
+                tokenizer,
+                temperature,
+                group_size
+            )
+
+            if completions is not None:
+                for completion_ids in completions:
+                    completion_text = tokenizer.decode(completion_ids.tolist())
+                    all_completions.append(completion_ids)
+                    all_completion_texts.append(completion_text)
+                    mx.eval(completion_ids)
+        except Exception as e:
+            print(f"Generation error: {e}")
+            continue

     expanded_answers = []
     expanded_prompts = []
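A hedged usage sketch of the new call pattern (generate_batch_completions is a hypothetical wrapper name; generate_grpo, model, tokenizer, and the prompt token lists come from the diff above): the training loop now stacks the whole batch of prompts and asks for group_size completions per prompt in one call, instead of looping group_size times per prompt.

import mlx.core as mx

def generate_batch_completions(model, tokenizer, batch_prompts, max_tokens, temperature, group_size):
    # Stack prompts into one (batch, prompt_len) tensor; note that mx.stack
    # assumes the prompts in the batch share the same length, as in the diff.
    prompt_tensor = mx.stack([mx.array(p) for p in batch_prompts])
    completions = generate_grpo(model, prompt_tensor, max_tokens, tokenizer, temperature, group_size)
    return [tokenizer.decode(ids.tolist()) for ids in (completions or [])]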
@@ -242,15 +298,12 @@ def grpo_loss(
     token_log_probs = get_per_token_logps(model, inputs, lengths)

     mx.eval(token_log_probs)
-    mx.metal.clear_cache()

-    # Reference policy probabilities
     if ref_model is None:
         ref_token_log_probs = token_log_probs
     else:
         ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
         mx.eval(ref_token_log_probs)
-        mx.metal.clear_cache()

     max_len = max(x.shape[0] for x in token_log_probs)
     padded_log_probs = []
@@ -339,6 +392,10 @@ def grpo_loss(
         'kl': mean_kl,
         **reward_metrics
     }
+
+    if is_validation:
+        print(f"\nValidation sample generation:\n{all_completion_texts[-1]}\n")
+
     mx.metal.clear_cache()

     return loss, sequence_lengths.sum(), metrics
@@ -412,7 +469,7 @@ def evaluate_grpo(
 ):
     all_losses = 0
     ntokens = 0
     all_metrics = None

     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
@@ -434,7 +491,8 @@ def evaluate_grpo(
             epsilon=epsilon,
             ref_model=ref_model,
             temperature=temperature,
-            max_tokens=max_tokens
+            max_tokens=max_tokens,
+            is_validation=True
         )

         all_losses += losses * toks