Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-26 02:33:23 +08:00)
clean up and readding temperature argument
Parent: d653371e3d
Commit: d9c4c6e60c
@@ -131,7 +131,7 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list:
     return scores
 
 
-def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False, end_token: str = "</answer>"):
+def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False, end_token: str = "</answer>", temperature: float = 0.8):
     if model.training == False:
         print("Model is in training mode", model.training, "Manually setting to eval mode")
         model.train()
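The only change in this hunk is the new keyword argument `temperature: float = 0.8` on `generate_grpo`. A minimal sketch of a call site that uses it; the model path, prompt, and variable names below are placeholders for illustration and are not part of the commit (it assumes `mlx_lm` is installed and that this file's `generate_grpo` is in scope):

import mlx.core as mx
from mlx_lm import load

# Placeholder model repo; any MLX-compatible instruct model would do.
model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")

# One tokenized prompt as an mx.array (the real function's shape handling is not shown in this diff).
prompts = mx.array([tokenizer.encode("What is 2 + 2?")])

completions = generate_grpo(
    model,
    prompts,
    max_tokens=256,
    tokenizer=tokenizer,
    group_size=4,
    is_training=True,
    temperature=0.8,   # the temperature argument this commit (re-)adds
)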
@@ -146,8 +146,6 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
     end_sequence = mx.array(tokenizer.encode(end_token))
-
     results = []
     tokens_generated = 0
     start_time = time.perf_counter()
-
     try:
         for idx in range(batch_size):
@@ -160,20 +158,23 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
                 logits = model(current_input[None], cache=prompt_cache)[:, -1]
 
                 while len(current_tokens) < max_tokens:
-                    probs = nn.softmax(logits, axis=-1)
+                    logits_temp = logits / temperature
+                    probs = nn.softmax(logits_temp, axis=-1)
                     next_token = mx.argmax(probs, axis=-1)
                     token = next_token.item()
-                    if token == tokenizer.eos_token_id:
-                        break
-                    if (len(current_tokens) >= len(end_sequence) and
+
+                    test_sequence = current_tokens + [token]
+                    if (len(test_sequence) >= len(end_sequence) and
                         mx.array_equal(
-                            mx.array(current_tokens[-len(end_sequence):]),
+                            mx.array(test_sequence[-len(end_sequence):]),
                             end_sequence
                         )):
                         break
-
+                    if token == tokenizer.eos_token_id:
+                        break
+
                     current_tokens.append(token)
                     tokens_generated += 1
                     current_input = mx.array([token])
                     logits = model(current_input[None], cache=prompt_cache)[:, -1]
                     mx.eval(current_input)
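Two things are worth noting about the sampling change above, since the loop still selects tokens with `mx.argmax`: dividing the logits by `temperature` changes how peaked the softmax distribution is, but it never changes which token has the largest probability, so greedy argmax decoding picks the same token for every positive temperature. The temperature only takes effect once tokens are actually drawn from the scaled distribution (which matters for GRPO-style training, where the `group_size` completions for a prompt need to differ for their rewards to differ). A small self-contained sketch of that distinction, independent of this commit and using made-up logits:

import mlx.core as mx

logits = mx.array([2.0, 1.0, 0.5])              # toy unnormalized scores
for temperature in (0.5, 1.0, 2.0):
    scaled = logits / temperature               # same scaling as in the diff above
    probs = mx.softmax(scaled, axis=-1)         # the spread changes with temperature...
    greedy = mx.argmax(probs, axis=-1)          # ...but the argmax never does
    sampled = mx.random.categorical(scaled)     # drawing a sample does respect it
    print(temperature, [round(p, 3) for p in probs.tolist()], greedy.item(), sampled.item())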
@@ -189,15 +190,12 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
                     if token == tokenizer.eos_token_id:
                         break
                     current_tokens.append(token)
                     tokens_generated += 1
-
             if current_tokens:
                 results.append(mx.array(current_tokens))
                 mx.metal.clear_cache()
-
         mx.eval(results)
         generation_time = time.perf_counter() - start_time
         print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
         return results
-
     except Exception as e:
@@ -266,7 +264,8 @@ def grpo_loss(
         max_tokens,
         tokenizer,
         group_size,
-        is_training=True
+        is_training=True,
+        temperature=temperature
     )
     if completions is not None:
         for completion_ids in completions:
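Here the call site inside `grpo_loss` simply forwards `temperature=temperature` into `generate_grpo`, which implies a `temperature` value is in scope inside `grpo_loss` (most likely its own parameter; the enclosing signature is outside this hunk, so that part is an assumption). A tiny illustration of the forwarding pattern with stand-in functions, not code from the repository:

def fake_generate(prompts, temperature=0.8):
    # Stand-in for generate_grpo: just report which temperature it received.
    return [f"completion(T={temperature})" for _ in prompts]

def fake_grpo_loss(prompts, temperature=0.8):
    # Forward the sampling temperature into rollout generation,
    # mirroring the `temperature=temperature` line in the hunk above.
    return fake_generate(prompts, temperature=temperature)

print(fake_grpo_loss(["p1", "p2"], temperature=0.6))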