Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-26 02:33:23 +08:00)

Commit c51b0a2715 (parent 710bc1490e): fix
@@ -76,10 +76,30 @@ def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kw
 def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
-    has_think = r"<think>.*</think>"
-    has_answer = r"<answer>.*</answer>"
-    matches = [(bool(re.search(has_think, r)) and bool(re.search(has_answer, r))) if r else False for r in completions]
-    return [0.5 if match else 0.0 for match in matches]
+    scores = []
+    for completion in completions:
+        if not completion:
+            scores.append(0.0)
+            continue
+
+        reason_start = completion.find("<think>")
+        reason_end = completion.find("</think>")
+        answer_start = completion.find("<answer>")
+        answer_end = completion.find("</answer>")
+
+        if (reason_start != -1 and reason_end != -1 and
+            answer_start != -1 and answer_end != -1 and
+            reason_start < reason_end < answer_start < answer_end):
+            reason_content = completion[reason_start+13:reason_end].strip()
+            answer_content = completion[answer_start+8:answer_end].strip()
+            if reason_content and answer_content:
+                scores.append(0.5)
+                continue
+
+        scores.append(0.0)
+
+    return scores
 
 def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     if not completions:
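For reference, a minimal standalone sketch of the tag-ordering logic that the rewritten r1_soft_format_reward_func applies per completion. The helper name and sample strings below are invented for illustration, and the content offsets use len() rather than the literal values in the diff:

# Illustrative only: scores 0.5 when a non-empty <think>...</think> block
# precedes a non-empty <answer>...</answer> block, else 0.0.
def soft_format_score(text: str) -> float:
    rs, re_ = text.find("<think>"), text.find("</think>")
    as_, ae = text.find("<answer>"), text.find("</answer>")
    if -1 in (rs, re_, as_, ae) or not (rs < re_ < as_ < ae):
        return 0.0
    reason = text[rs + len("<think>"):re_].strip()
    answer = text[as_ + len("<answer>"):ae].strip()
    return 0.5 if reason and answer else 0.0

print(soft_format_score("<think>2 + 2 = 4</think><answer>4</answer>"))   # 0.5
print(soft_format_score("<answer>4</answer><think>wrong order</think>")) # 0.0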
@@ -118,53 +138,49 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
     if prompts.shape[1] == 0:
         return None
 
+    model.eval()
     batch_size = prompts.shape[0] * group_size
     expanded_prompts = mx.repeat(prompts, group_size, axis=0)
-    mx.eval(expanded_prompts)
+    end_sequence = mx.array(tokenizer.encode("</answer>"))
 
     results = []
     tokens_generated = 0
     start_time = time.perf_counter()
 
-    for idx in range(batch_size):
-        current_prompt = expanded_prompts[idx:idx+1]
-        mx.eval(current_prompt)
+    try:
+        for idx in range(batch_size):
+            current_tokens = []
 
-        current_tokens = []
-        try:
             if is_training:
-                # Initialize with prompt
-                current_input = current_prompt[0]
-                mx.eval(current_input)
+                current_input = expanded_prompts[idx]
 
                 while len(current_tokens) < max_tokens:
-                    # Generate one token at a time
-                    logits = model(current_input[None])
-                    next_token = mx.random.categorical(logits[:, -1, :])
+                    logits = model(current_input[None])[:, -1]
+                    next_token = mx.argmax(logits, axis=-1)
                     token = next_token.item()
                     current_tokens.append(token)
                     tokens_generated += 1
 
-                    # Clear intermediate results
-                    mx.eval(next_token)
-                    del logits
-
                    if token == tokenizer.eos_token_id:
                         break
 
-                    # Update input for next iteration
-                    current_input = mx.array([token])
-                    mx.eval(current_input)
+                    if (len(current_tokens) >= len(end_sequence) and
+                        mx.array_equal(
+                            mx.array(current_tokens[-len(end_sequence):]),
+                            end_sequence
+                        )):
+                        break
 
-                    # Clear cache periodically
-                    if len(current_tokens) % 8 == 0:
+                    current_input = mx.concatenate([current_input, mx.array([token])])
+                    if len(current_tokens) % 32 == 0:
+                        mx.eval(current_input)
                         mx.metal.clear_cache()
             else:
                 generator = generate_step(
-                    current_prompt[0],
+                    expanded_prompts[idx],
                     model,
                     max_tokens=max_tokens,
-                    sampler=lambda x: mx.random.categorical(x)
+                    sampler=lambda x: mx.argmax(x, axis=-1)
                 )
 
                 for token, _ in generator:
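The training branch above replaces mx.random.categorical sampling with greedy mx.argmax decoding and adds an early stop when the generated tail matches the encoded "</answer>" sequence. A self-contained sketch of that loop, using a stand-in model and made-up token ids (nothing below comes from the repository):

import mlx.core as mx

# Stand-in "model": returns random logits of shape (batch, seq_len, vocab).
def dummy_model(tokens, vocab_size=16):
    return mx.random.normal((tokens.shape[0], tokens.shape[1], vocab_size))

eos_token_id = 0                      # hypothetical EOS id
end_sequence = mx.array([3, 1, 4])    # hypothetical encoding of "</answer>"
current_input = mx.array([5, 6, 7])   # hypothetical prompt tokens
current_tokens = []

while len(current_tokens) < 32:
    logits = dummy_model(current_input[None])[:, -1]   # logits at the last position
    token = mx.argmax(logits, axis=-1).item()          # greedy decoding
    current_tokens.append(token)

    if token == eos_token_id:
        break
    # Stop once the generated tail equals the end sequence.
    if (len(current_tokens) >= len(end_sequence) and
            mx.array_equal(mx.array(current_tokens[-len(end_sequence):]), end_sequence)):
        break

    current_input = mx.concatenate([current_input, mx.array([token])])

print(current_tokens)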
@@ -174,28 +190,18 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
                         break
 
             if current_tokens:
-                token_array = mx.array(current_tokens)
-                mx.eval(token_array)
-                results.append(token_array)
-                del token_array
+                results.append(mx.array(current_tokens))
+                mx.metal.clear_cache()
 
-        except Exception as e:
-            print(f"Generation failed for sequence {idx}: {e}")
-            continue
+        mx.eval(results)
+        generation_time = time.perf_counter() - start_time
+        print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
+        return results
 
-        mx.metal.clear_cache()
-
-    if not results:
-        print("No successful generations")
+    except Exception as e:
+        print(f"Generation error: {str(e)}")
         return None
 
-    mx.eval(results)
-
-    generation_time = time.perf_counter() - start_time
-    print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
-
-    return results
-
 
 def get_per_token_logps(model: nn.Module, inputs, lengths):
     logits = model(inputs).astype(mx.float16)
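The restructured tail of generate_grpo calls mx.eval(results) so the lazily built token arrays are actually computed before the throughput is reported and the list is returned. A minimal sketch of evaluating a list of lazy arrays, with made-up data:

import time
import mlx.core as mx

# Lazily-built stand-in results (no computation has happened yet).
results = [mx.arange(64) * i for i in range(4)]

start_time = time.perf_counter()
mx.eval(results)                      # force evaluation of every array in the list
elapsed = time.perf_counter() - start_time
print(f"Evaluated {len(results)} sequences in {elapsed:.4f}s")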
@@ -243,15 +249,23 @@ def grpo_loss(
         prompt_tensor = mx.stack([mx.array(p) for p in batch_prompts])
 
         try:
-            completions = generate_grpo(
-                model,
-                prompt_tensor,
-                max_tokens,
-                tokenizer,
-                group_size,
-                True
-            )
+            if is_validation:
+                completions = generate_grpo(
+                    model,
+                    prompt_tensor,
+                    max_tokens,
+                    tokenizer,
+                    group_size
+                )
+            else:
+                completions = generate_grpo(
+                    model,
+                    prompt_tensor,
+                    max_tokens,
+                    tokenizer,
+                    group_size,
+                    is_training=True
+                )
             if completions is not None:
                 for completion_ids in completions:
                     completion_text = tokenizer.decode(completion_ids.tolist())
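The two branches above pass the same positional arguments and differ only in is_training=True; the validation path relies on the parameter's default. Assuming that default is False, an equivalent single call would be the sketch below (an illustration only, not what the commit does):

completions = generate_grpo(
    model,
    prompt_tensor,
    max_tokens,
    tokenizer,
    group_size,
    is_training=not is_validation,
)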
@@ -262,8 +276,6 @@ def grpo_loss(
             print(f"Generation error: {e}")
             continue
 
-        mx.metal.clear_cache()
-
     expanded_answers = []
     expanded_prompts = []
     for i in range(batch_size):