mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-27 11:21:32 +08:00
Slightly faster generation + prints an example generation in validation mode, more optimization in the training function
This commit is contained in:
parent 11c8991476
commit 2f20107d9b
@@ -374,7 +374,9 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
         max_seq_length=args.max_seq_length,
         beta=args.beta,
         group_size=args.group_size,
-        epsilon=args.epsilon
+        epsilon=args.epsilon,
+        temperature=args.temperature,
+        max_tokens=args.max_seq_length
     )

     test_ppl = math.exp(test_loss)
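The `test_ppl = math.exp(test_loss)` line above is the usual conversion from a mean per-token negative log-likelihood to perplexity. A minimal standalone illustration (the loss value is made up):

```python
import math

# Perplexity is the exponential of the mean per-token negative log-likelihood.
test_loss = 1.25  # illustrative value only
test_ppl = math.exp(test_loss)
print(f"Test loss {test_loss:.3f} -> perplexity {test_ppl:.2f}")  # ~3.49
```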
@@ -74,15 +74,14 @@ def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kw
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
     return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]


 def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
-    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
-    matches = [bool(re.search(pattern, r)) if r else False for r in completions]
+    has_think = r"<think>.*</think>"
+    has_answer = r"<answer>.*</answer>"
+    matches = [(bool(re.search(has_think, r)) and bool(re.search(has_answer, r))) if r else False for r in completions]
     return [0.5 if match else 0.0 for match in matches]


 def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
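The reworked `r1_soft_format_reward_func` no longer requires `<think>` to come immediately before `<answer>` in a single non-greedy match; it only checks that both tag pairs appear somewhere in the completion. A minimal sketch of the behavioral difference (the sample completion string is invented for illustration):

```python
import re

old_pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
has_think = r"<think>.*</think>"
has_answer = r"<answer>.*</answer>"

# Tags present but in the "wrong" order relative to the old pattern.
completion = "<answer>42</answer> <think>worked backwards</think>"

old_match = bool(re.search(old_pattern, completion))  # False: ordering enforced
new_match = bool(re.search(has_think, completion)) and bool(re.search(has_answer, completion))  # True
print(old_match, new_match)
```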
@@ -114,44 +113,95 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
     return scores


-def generate_grpo(model: nn.Module, prompt, max_tokens, tokenizer, temperature):
-    if len(prompt.shape) == 1:
-        prompt = prompt[None, :]
-    if prompt.shape[1] == 0:
+def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, temperature, group_size):
+    if len(prompts.shape) == 1:
+        prompts = prompts[None, :]
+    if prompts.shape[1] == 0:
         return None
-    end_sequence = tokenizer.encode("</answer>")
-    end_sequence_length = len(end_sequence)
-    initial_length = prompt.shape[1]
-    output = mx.zeros((initial_length + max_tokens,), dtype=mx.int32)
-    output[:initial_length] = prompt[0]
-    current_length = initial_length
-
     start_time = time.perf_counter()
     tokens_generated = 0
+    batch_size = prompts.shape[0] * group_size
+
+    # Repeat each prompt group_size times
+    expanded_prompts = mx.repeat(prompts, group_size, axis=0)
+
+    end_sequence = mx.array(tokenizer.encode("</answer>"))
+    end_len = len(end_sequence)
+    initial_length = prompts.shape[1]
+
+    # Initialize output tensor for all sequences
+    output = mx.zeros((batch_size, initial_length + max_tokens), dtype=mx.int32)
+    output = mx.concatenate([expanded_prompts, mx.zeros((batch_size, max_tokens), dtype=mx.int32)], axis=1)
+    current_lengths = mx.array([initial_length] * batch_size)
+
+    temp_factor = 1/temperature if temperature > 0 else float('inf')

-    def sample(logits):
-        if temperature > 0:
-            logits /= temperature
-        logprobs = logits - mx.logsumexp(logits, keepdims=True)
-        return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]
-
-    for _ in range(max_tokens):
-        current_input = output[:current_length][None, :]
-        logits = model(current_input)
-        token_logits = logits[0, -1]
-        next_token = sample(token_logits)
-        token_value = next_token.item()
-        output[current_length] = token_value
-        current_length += 1
-        if token_value == tokenizer.eos_token_id:
-            break
-        if current_length >= end_sequence_length:
-            last_tokens = output[current_length - end_sequence_length:current_length].tolist()
-            if last_tokens == end_sequence:
-                break
-    if current_length > initial_length:
-        return output[:current_length]
+    try:
+        not_finished = mx.ones((batch_size,), dtype=mx.bool_)
+
+        for _ in range(max_tokens):
+            # Check if all sequences are finished
+            if not mx.sum(not_finished).item():
+                break
+
+            # Get model outputs for all sequences
+            max_len = mx.max(current_lengths).item()
+            batch_inputs = output[:, :max_len]
+            logits = model(batch_inputs)[:, -1]
+
+            # Apply mask to logits
+            logits = logits * mx.expand_dims(not_finished, -1)
+
+            # Sample next tokens
+            logits *= temp_factor
+            logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
+            next_tokens = mx.random.categorical(logprobs)
+
+            # Update outputs for active sequences
+            for idx in range(batch_size):
+                if not_finished[idx].item():
+                    curr_len = current_lengths[idx].item()
+                    token_value = next_tokens[idx].item()
+
+                    # Create new arrays with updates
+                    output = mx.array(output.tolist())  # Make a copy
+                    output[idx, curr_len] = token_value
+                    current_lengths = mx.array([
+                        l + 1 if i == idx else l
+                        for i, l in enumerate(current_lengths.tolist())
+                    ])
+                    tokens_generated += 1
+
+                    # Check end conditions
+                    if token_value == tokenizer.eos_token_id:
+                        not_finished = mx.array([
+                            False if i == idx else nf
+                            for i, nf in enumerate(not_finished.tolist())
+                        ])
+                    elif curr_len >= end_len:
+                        last_tokens = output[idx, curr_len-end_len+1:curr_len+1]
+                        if mx.array_equal(last_tokens, end_sequence):
+                            not_finished = mx.array([
+                                False if i == idx else nf
+                                for i, nf in enumerate(not_finished.tolist())
+                            ])
+
+            if _ % 32 == 0:
+                mx.eval(output, current_lengths, not_finished)
+
+        end_time = time.perf_counter()
+        generation_time = end_time - start_time
+        tokens_per_second = tokens_generated / generation_time
+        print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_per_second:.2f} tokens/s)")
+
+        # Return only the valid part of each sequence
+        results = [output[i, :current_lengths[i].item()] for i in range(batch_size)]
+        return results
+
+    except Exception as e:
+        print(f"Generation error: {str(e)}")
+        return None
+
     return None


 def get_per_token_logps(model: nn.Module, inputs, lengths):
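The rewritten `generate_grpo` above samples all `prompts.shape[0] * group_size` completions in one batched decode loop instead of generating them one at a time, which appears to be where the speedup in the commit message comes from (the per-index Python loop that writes tokens back is still serial). A minimal sketch of the expand-and-sample step in MLX; the random logits stand in for `model(batch_inputs)[:, -1]`, and the vocabulary size is an assumption:

```python
import mlx.core as mx

group_size = 4
temperature = 0.8

prompts = mx.array([[1, 2, 3], [4, 5, 6]])          # (num_prompts, prompt_len)
expanded = mx.repeat(prompts, group_size, axis=0)    # (num_prompts * group_size, prompt_len)

logits = mx.random.normal((expanded.shape[0], 32))   # stand-in for model(batch_inputs)[:, -1]
logits = logits * (1 / temperature)                  # same scaling as temp_factor in the diff
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
next_tokens = mx.random.categorical(logprobs)        # one sampled token id per row
print(expanded.shape, next_tokens.shape)             # (8, 3) (8,)
```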
@@ -185,7 +235,8 @@ def grpo_loss(
     epsilon=1e-4,
     max_tokens=64,
     temperature=1.0,
-    reward_weights=None
+    reward_weights=None,
+    is_validation=False
 ):
     prompt_tokens, _, prompt_text, answer_text = batch
     batch_size = len(prompt_tokens)
@@ -195,22 +246,27 @@ def grpo_loss(
     for i in range(0, batch_size, batch_size):
         batch_prompts = prompt_tokens[i:i+batch_size]
-        for prompt in batch_prompts:
-            prompt_tensor = mx.array(prompt)
-            for _ in range(group_size):
-                try:
-                    completion_ids = generate_grpo(model, prompt_tensor, max_tokens, tokenizer, temperature)
-                    if completion_ids is not None:
-                        completion_text = tokenizer.decode(completion_ids.tolist())
-                        all_completions.append(completion_ids)
-                        all_completion_texts.append(completion_text)
-                        mx.eval(completion_ids)
-                        del completion_ids
-                except Exception as e:
-                    print(f"Generation error: {e}")
-                    continue
-
-        mx.metal.clear_cache()
+        prompt_tensor = mx.stack([mx.array(p) for p in batch_prompts])
+
+        try:
+            completions = generate_grpo(
+                model,
+                prompt_tensor,
+                max_tokens,
+                tokenizer,
+                temperature,
+                group_size
+            )
+
+            if completions is not None:
+                for completion_ids in completions:
+                    completion_text = tokenizer.decode(completion_ids.tolist())
+                    all_completions.append(completion_ids)
+                    all_completion_texts.append(completion_text)
+                    mx.eval(completion_ids)
+        except Exception as e:
+            print(f"Generation error: {e}")
+            continue

     expanded_answers = []
     expanded_prompts = []
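Note that `mx.stack([mx.array(p) for p in batch_prompts])` only produces a rectangular tensor when every prompt in the batch has the same token length. A minimal sketch of left-padding to a common length in case they differ; the `pad_prompts` helper and the `pad_id` value are assumptions for illustration, not part of the diff:

```python
import mlx.core as mx

def pad_prompts(batch_prompts, pad_id):
    # Left-pad token lists to a common length so mx.stack yields a (batch, max_len) array.
    max_len = max(len(p) for p in batch_prompts)
    padded = [[pad_id] * (max_len - len(p)) + list(p) for p in batch_prompts]
    return mx.stack([mx.array(p) for p in padded])

prompt_tensor = pad_prompts([[1, 2, 3], [4, 5]], pad_id=0)
print(prompt_tensor.shape)  # (2, 3)
```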
@@ -242,15 +298,12 @@ def grpo_loss(
     token_log_probs = get_per_token_logps(model, inputs, lengths)

-    mx.eval(token_log_probs)
-    mx.metal.clear_cache()
-
+    # Reference policy probabilities
     if ref_model is None:
         ref_token_log_probs = token_log_probs
     else:
         ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
-        mx.eval(ref_token_log_probs)
-        mx.metal.clear_cache()

     max_len = max(x.shape[0] for x in token_log_probs)
     padded_log_probs = []
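This hunk only drops the intermediate `mx.eval`/`mx.metal.clear_cache` calls around the reference log-probabilities; it does not change how the KL term itself is computed. For reference, the per-token estimator commonly used in GRPO implementations is sketched below; this is an assumption about the surrounding code, not something shown in the diff. When `ref_model` is None the diff aliases `ref_token_log_probs` to `token_log_probs`, so such a term would be identically zero:

```python
import mlx.core as mx

def per_token_kl(token_log_probs, ref_token_log_probs):
    # k3-style estimator: exp(ref - logp) - (ref - logp) - 1, non-negative and
    # exactly zero where the policy matches the reference.
    diff = ref_token_log_probs - token_log_probs
    return mx.exp(diff) - diff - 1

print(per_token_kl(mx.array([-1.0, -2.0]), mx.array([-1.0, -2.5])))  # [0.0, ~0.107]
```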
@@ -339,6 +392,10 @@ def grpo_loss(
         'kl': mean_kl,
         **reward_metrics
     }

+    if is_validation:
+        print(f"\nValidation sample generation:\n{all_completion_texts[-1]}\n")
+
+    mx.metal.clear_cache()

     return loss, sequence_lengths.sum(), metrics
@@ -412,7 +469,7 @@ def evaluate_grpo(
 ):
     all_losses = 0
     ntokens = 0
-    all_metrics = None
+    all_metrics = None

     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)

@@ -434,7 +491,8 @@ def evaluate_grpo(
             epsilon=epsilon,
             ref_model=ref_model,
             temperature=temperature,
-            max_tokens=max_tokens
+            max_tokens=max_tokens,
+            is_validation=True
         )

         all_losses += losses * toks
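The `all_losses += losses * toks` accumulation above follows the usual token-weighted averaging pattern, where the reported loss is the ratio of the two running sums. A minimal sketch of that reduction; the batch values and the final division are illustrative assumptions, not lines from the diff:

```python
# Token-weighted running average of per-batch mean losses.
all_losses = 0.0
ntokens = 0
for batch_loss, toks in [(2.0, 100), (1.0, 300)]:  # made-up (mean_loss, token_count) pairs
    all_losses += batch_loss * toks
    ntokens += toks
avg_loss = all_losses / ntokens
print(avg_loss)  # 1.25, not the unweighted mean 1.5
```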