Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-26 10:41:18 +08:00)
clean up and re-adding temperature argument
commit d9c4c6e60c
parent d653371e3d
@@ -131,7 +131,7 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
     return scores


-def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False, end_token: str = "</answer>"):
+def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False, end_token: str = "</answer>", temperature: float = 0.8):
     if model.training == False:
         print("Model is in training mode", model.training, "Manually setting to eval mode")
         model.train()
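The signature now re-exposes the sampling temperature as a keyword argument defaulting to 0.8. A minimal, hypothetical call site (the model, tokenizer, and prompt preparation are assumed and not shown in this diff):

# Illustrative only: the concrete argument values are assumptions, not from the commit.
completions = generate_grpo(
    model,
    prompts,               # tokenized prompts prepared by the caller
    max_tokens=512,        # assumed per-completion budget
    tokenizer=tokenizer,
    group_size=4,          # assumed GRPO group size
    is_training=True,
    temperature=0.8,       # keyword re-added by this commit
)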
@@ -146,8 +146,6 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
     end_sequence = mx.array(tokenizer.encode(end_token))

     results = []
-    tokens_generated = 0
-    start_time = time.perf_counter()

     try:
         for idx in range(batch_size):
@@ -160,20 +158,23 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
             logits = model(current_input[None], cache=prompt_cache)[:, -1]

             while len(current_tokens) < max_tokens:
-                probs = nn.softmax(logits, axis=-1)
+                logits_temp = logits / temperature
+                probs = nn.softmax(logits_temp, axis=-1)
                 next_token = mx.argmax(probs, axis=-1)
                 token = next_token.item()
-                if token == tokenizer.eos_token_id:
-                    break
-                if (len(current_tokens) >= len(end_sequence) and
+                test_sequence = current_tokens + [token]
+                if (len(test_sequence) >= len(end_sequence) and
                     mx.array_equal(
-                        mx.array(current_tokens[-len(end_sequence):]),
+                        mx.array(test_sequence[-len(end_sequence):]),
                         end_sequence
                     )):
                     break

+                if token == tokenizer.eos_token_id:
+                    break
+
                 current_tokens.append(token)
-                tokens_generated += 1
                 current_input = mx.array([token])
                 logits = model(current_input[None], cache=prompt_cache)[:, -1]
                 mx.eval(current_input)
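Dividing the logits by `temperature` before the softmax sharpens the distribution for values below 1 and flattens it for values above 1. Note that the selection step is still `mx.argmax`, and a positive rescaling never changes which entry is largest, so the temperature only affects the generated text once tokens are drawn stochastically. A minimal sketch of such a sampling step with MLX's categorical sampler, shown for illustration only (the `sample_token` helper is an assumption, not part of this commit):

import mlx.core as mx

def sample_token(logits: mx.array, temperature: float = 0.8) -> int:
    # Illustrative helper: draw one token id from temperature-scaled logits.
    if temperature <= 0:
        # Degenerate setting: fall back to greedy decoding.
        return mx.argmax(logits, axis=-1).item()
    # mx.random.categorical samples from unnormalized logits directly,
    # so no explicit softmax is needed before drawing.
    return mx.random.categorical(logits / temperature, axis=-1).item()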
@@ -189,15 +190,12 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
                 if token == tokenizer.eos_token_id:
                     break
                 current_tokens.append(token)
-                tokens_generated += 1

             if current_tokens:
                 results.append(mx.array(current_tokens))
             mx.metal.clear_cache()

         mx.eval(results)
-        generation_time = time.perf_counter() - start_time
-        print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
         return results

     except Exception as e:
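With the per-token counter and timing print removed, what remains is the collection pattern: each completion is converted to an `mx.array`, `mx.metal.clear_cache()` releases cached Metal buffers between prompts, and a final `mx.eval(results)` makes sure every array is fully evaluated before returning. A stripped-down sketch of that pattern, with placeholder data standing in for real completions:

import mlx.core as mx

results = []
for _ in range(4):                      # assumed batch size
    tokens = [1, 2, 3]                  # placeholder for one generated completion
    results.append(mx.array(tokens))    # collect one array per completion
    mx.metal.clear_cache()              # drop cached Metal buffers between prompts
mx.eval(results)                        # ensure all collected arrays are evaluated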
@@ -266,7 +264,8 @@ def grpo_loss(
         max_tokens,
         tokenizer,
         group_size,
-        is_training=True
+        is_training=True,
+        temperature=temperature
     )
     if completions is not None:
         for completion_ids in completions:
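The call site forwards `temperature=temperature`, so `grpo_loss` evidently has a temperature value in scope, most likely a parameter of its own; that part of the signature is not visible in this hunk. A sketch of one plausible shape of the plumbing, with every name and default below assumed rather than taken from the commit:

# Hypothetical outline: thread the caller's sampling temperature down to generation.
def grpo_loss(model, tokenizer, prompts, max_tokens, group_size,
              temperature: float = 0.8):
    completions = generate_grpo(
        model,
        prompts,
        max_tokens,
        tokenizer,
        group_size,
        is_training=True,
        temperature=temperature,   # forwarded unchanged to generate_grpo
    )
    if completions is not None:
        for completion_ids in completions:
            ...  # decode and score each completion as in the real implementation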