Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-12-16 02:08:55 +08:00
fix wrong generation in train
@@ -132,21 +132,15 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
 
 def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False, end_token: str = "</answer>", temperature: float = 0.8):
-    if model.training == False:
-        print("Model is in training mode", model.training, "Manually setting to eval mode")
-        model.train()
-
     if len(prompts.shape) == 1:
         prompts = prompts[None, :]
     if prompts.shape[1] == 0:
         return None
 
     batch_size = prompts.shape[0] * group_size
     expanded_prompts = mx.repeat(prompts, group_size, axis=0)
     end_sequence = mx.array(tokenizer.encode(end_token))
 
     results = []
+    mx.eval(expanded_prompts)
     try:
         for idx in range(batch_size):
             current_tokens = []
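For context on the hunk above: the deleted guard printed "Manually setting to eval mode" but then called model.train(), i.e. it flipped the model into training mode rather than eval, so removing it leaves the caller's mode untouched. The retained setup repeats each prompt group_size times so every prompt gets group_size sampled completions, and the new mx.eval forces the lazily built expanded batch to be materialized before decoding starts. A minimal standalone sketch of that expansion (toy token ids, no model; everything here is illustrative, not code from the repo):

import mlx.core as mx

# Two toy prompts of length 3 (token ids made up for illustration).
prompts = mx.array([[101, 7, 9], [101, 4, 2]])
group_size = 4

# Row i of the expanded batch corresponds to prompt i // group_size.
expanded_prompts = mx.repeat(prompts, group_size, axis=0)

# MLX is lazy: evaluate now instead of inside the sampling loop.
mx.eval(expanded_prompts)

print(expanded_prompts.shape)  # (8, 3) == (num_prompts * group_size, prompt_len)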
@@ -154,8 +148,8 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
             if is_training:
                 current_input = expanded_prompts[idx]
                 prompt_cache = cache.make_prompt_cache(model)
-
                 logits = model(current_input[None], cache=prompt_cache)[:, -1]
+                mx.eval(logits, prompt_cache)
 
                 while len(current_tokens) < max_tokens:
                     logits_temp = logits / temperature
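In this branch the prompt is run through the model once to prime a KV cache, and each later step feeds only the newly sampled token while reading the logits of the last position; the added mx.eval forces that first forward pass (and the cache it fills) to be computed eagerly. A minimal sketch of the same cached, token-by-token decode using mlx-lm, assuming a model/tokenizer loaded with mlx_lm.load (the model id and greedy sampling are illustrative assumptions, not the trainer's settings):

import mlx.core as mx
from mlx_lm import load
from mlx_lm.models import cache

# Assumption: any mlx-lm chat model works here; this id is only an example.
model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")

prompt = mx.array(tokenizer.encode("Briefly explain GRPO."))
prompt_cache = cache.make_prompt_cache(model)

# Prime the cache with the full prompt; keep only the last position's logits.
logits = model(prompt[None], cache=prompt_cache)[:, -1]

tokens = []
for _ in range(32):
    token = mx.argmax(logits, axis=-1).item()  # greedy here; the trainer samples with temperature
    if token == tokenizer.eos_token_id:
        break
    tokens.append(token)
    # Feed only the new token; the cache already holds the earlier context.
    logits = model(mx.array([token])[None], cache=prompt_cache)[:, -1]
    mx.eval(logits)  # force the lazy graph each step, as the patch does after the prompt pass

print(tokenizer.decode(tokens))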
@@ -169,16 +163,16 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
                         mx.array(test_sequence[-len(end_sequence):]),
                         end_sequence
                     )):
+                        current_tokens.append(token)
                         break
 
                     if token == tokenizer.eos_token_id:
                         break
 
                     current_tokens.append(token)
                     current_input = mx.array([token])
                     logits = model(current_input[None], cache=prompt_cache)[:, -1]
-                    mx.eval(current_input)
-                    mx.metal.clear_cache()
+                    mx.eval(current_input, logits, probs, next_token)
             else:
                 generator = generate_step(
                     expanded_prompts[idx],
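The behavioural fix in this hunk: when the sampled token completes the end sequence, it is now appended before breaking, so the completion keeps its closing </answer> tokens instead of dropping the last one; the per-step mx.metal.clear_cache() call is removed, and the final mx.eval also touches logits, probs and next_token so the whole step gets evaluated. A tiny self-contained sketch of the suffix check and the patched control flow (plain Python lists, made-up token ids, not the repo's code):

def ends_with(tokens: list[int], end_sequence: list[int]) -> bool:
    # True when the tail of `tokens` equals `end_sequence`.
    return len(tokens) >= len(end_sequence) and tokens[-len(end_sequence):] == end_sequence

end_sequence = [27, 9399, 29]             # made-up ids standing in for "</answer>"
current_tokens: list[int] = []

for token in [11, 42, 27, 9399, 29, 7]:   # pretend these arrive one by one from the sampler
    test_sequence = current_tokens + [token]
    if ends_with(test_sequence, end_sequence):
        current_tokens.append(token)      # keep the matching token, as the fix does
        break
    current_tokens.append(token)

print(current_tokens)                     # [11, 42, 27, 9399, 29] -- the end tag is retained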