Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-26 02:33:23 +08:00
fix cache handling
This commit is contained in:
parent
7b0141455e
commit
0a09a93454
@@ -36,41 +36,6 @@ class GRPOTrainingArgs(TrainingArgs):
     )
 
 
-def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
-    if len(prompt.shape) == 1:
-        prompt = prompt[None, :]
-
-    generated = []
-    current_prompt = prompt[0]
-
-    for _ in range(max_tokens):
-        current_batch = current_prompt[None, :]
-        logits = model(current_batch)
-        token_logits = logits[0, -1]
-
-        if temperature > 0:
-            token_logits = token_logits / temperature
-
-        probs = mx.softmax(token_logits)
-        next_token = mx.random.categorical(probs[None, :])
-        next_token = next_token[0]
-        mx.eval(next_token)
-
-        token_value = next_token.item()
-        generated.append(next_token)
-
-        current_prompt = mx.concatenate([current_prompt, next_token[None]])
-        if token_value == tokenizer.eos_token_id:
-            break
-
-    if not generated:
-        return prompt[0]
-
-    result = mx.concatenate([prompt[0], mx.stack(generated)])
-    mx.eval(result)
-    return result
-
-
 def r1_extract_xml_answer(text: str) -> str:
     """Extracts the answer from an XML formatted text string."""
     try:
@@ -154,9 +119,48 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
     return scores
 
 
+def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
+    if len(prompt.shape) == 1:
+        prompt = prompt[None, :]
+    if prompt.shape[1] == 0:
+        return None
+
+    output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32)
+    output[:prompt.shape[1]] = prompt[0]
+    current_length = prompt.shape[1]
+
+    try:
+        for _ in range(max_tokens):
+            current_input = output[:current_length][None, :]
+            logits = model(current_input)
+            token_logits = logits[0, -1]
+
+            if temperature > 0:
+                token_logits /= temperature
+
+            probs = mx.softmax(token_logits)
+            next_token = mx.random.categorical(probs[None, :]).astype(mx.int32)
+            next_token = next_token[0]
+
+            token_value = next_token.item()
+            output[current_length] = token_value
+            current_length += 1
+
+            if token_value == tokenizer.eos_token_id:
+                break
+
+        if current_length > prompt.shape[1]:
+            result = output[:current_length]
+            return result
+
+    except Exception as e:
+        print(f"Generation error: {str(e)}")
+        return None
+    return None
+
+
 def get_per_token_logps(model, inputs, lengths):
-    # Get logits from model
-    logits = model(inputs).astype(mx.float32)  # [batch_size, seq_len, vocab_size]
+    logits = model(inputs).astype(mx.float16)  # [batch_size, seq_len, vocab_size]
     # Remove last position as it corresponds to the next token prediction
     logits = logits[:, :-1, :]  # [batch_size, seq_len-1, vocab_size]
     targets = inputs[:, 1:]  # Shift inputs to get targets
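For context, a minimal sketch of how the rewritten generate_grpo is expected to be called. Only generate_grpo itself comes from this commit; the mlx_lm load() call and the model name below are illustrative assumptions.

import mlx.core as mx
from mlx_lm import load  # assumed loader; any (model, tokenizer) pair with the same interface works

# generate_grpo from the diff above is assumed to be in scope
model, tokenizer = load("Qwen/Qwen2.5-0.5B-Instruct")  # hypothetical model name
prompt_ids = mx.array(tokenizer.encode("What is 2 + 2?"))

# Returns an int32 vector of prompt + generated token ids, or None on failure
completion_ids = generate_grpo(model, prompt_ids, max_tokens=64, tokenizer=tokenizer, temperature=0.8)
if completion_ids is not None:
    print(tokenizer.decode(completion_ids.tolist()))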
@@ -182,6 +186,7 @@ def get_per_token_logps(model, inputs, lengths):
         ).squeeze(-1)  # [seq_len]
 
         per_token_logps.append(token_log_probs)
+    mx.eval(logits)
     return per_token_logps
 
 
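The added mx.eval(logits) matters because MLX computations are lazy; forcing evaluation before returning keeps per-sequence graphs from piling up. A standalone sketch of the general pattern (illustrative only, not trainer code):

import mlx.core as mx

x = mx.random.normal((1024, 1024))
y = (x @ x).sum()        # lazy: nothing has been computed yet
mx.eval(y)               # force the computation so the pending graph can be dropped
mx.metal.clear_cache()   # release cached Metal buffers held by the allocator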
@@ -204,22 +209,26 @@ def grpo_loss(
     all_completions = []
     all_completion_texts = []
 
-    for prompt in prompt_tokens:
-        prompt_tensor = mx.array(prompt)
-
-        for _ in range(group_size):
-            try:
-                completion_ids = generate_grpo(model, prompt_tensor, max_tokens, tokenizer, temperature)
-                if completion_ids is None:
-                    continue
-
-                completion_text = tokenizer.decode(completion_ids.tolist())
-                all_completions.append(completion_ids)
-                all_completion_texts.append(completion_text)
-
-            except Exception as e:
-                print(f"Generation error: {e}")
-                continue
+    for i in range(0, batch_size, batch_size):
+        batch_prompts = prompt_tokens[i:i+batch_size]
+        for prompt in batch_prompts:
+            prompt_tensor = mx.array(prompt)
+            for _ in range(group_size):
+                try:
+                    completion_ids = generate_grpo(model, prompt_tensor, max_tokens, tokenizer, temperature)
+                    if completion_ids is not None:
+                        completion_text = tokenizer.decode(completion_ids.tolist())
+                        all_completions.append(completion_ids)
+                        all_completion_texts.append(completion_text)
+
+                        # Clear completion tensors
+                        mx.eval(completion_ids)
+                        del completion_ids
+                except Exception as e:
+                    print(f"Generation error: {e}")
+                    continue
+
+        mx.metal.clear_cache()
 
     # Prepare inputs
     expanded_answers = []
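The reworked loop above combines three memory-related steps: evaluate tensors that will be kept, drop references that will not, and clear the Metal buffer cache once per batch. A compact sketch of the same pattern in isolation (names are illustrative, assuming any generate function that may return None):

import mlx.core as mx

def collect_completions(prompts, generate_fn):
    kept = []
    for prompt in prompts:
        out = generate_fn(prompt)      # possibly a large, still-lazy array
        if out is not None:
            mx.eval(out)               # materialize before keeping a reference
            kept.append(out)
        del out                        # drop the local handle either way
    mx.metal.clear_cache()             # hand cached Metal buffers back after the batch
    return kept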
@@ -250,6 +259,10 @@ def grpo_loss(
 
     # Current policy probabilities
     token_log_probs = get_per_token_logps(model, inputs, lengths)
+
+    mx.eval(token_log_probs)
+    mx.metal.clear_cache()
+
 
     # Reference policy probabilities
     if ref_model is not None:
@@ -263,7 +276,7 @@ def grpo_loss(
 
     for i in range(len(token_log_probs)):
         seq_len = token_log_probs[i].shape[0]
-        padding = mx.zeros((max_len - seq_len,), dtype=mx.float32)
+        padding = mx.zeros((max_len - seq_len,), dtype=mx.float16)
 
         padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
         padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))
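Since get_per_token_logps now casts logits to mx.float16, the padding is switched to float16 so the concatenation keeps a single dtype. A toy check of the promotion this avoids (assumed shapes, not trainer code):

import mlx.core as mx

log_probs = mx.zeros((5,), dtype=mx.float16)   # per-token log-probs, now float16
pad16 = mx.zeros((3,), dtype=mx.float16)
pad32 = mx.zeros((3,), dtype=mx.float32)

print(mx.concatenate([log_probs, pad16]).dtype)  # stays float16
print(mx.concatenate([log_probs, pad32]).dtype)  # promotes to float32, doubling memory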
@@ -330,6 +343,7 @@ def grpo_loss(
         'kl': mean_kl,
         **reward_metrics
     }
+    mx.metal.clear_cache()
 
     return loss, sequence_lengths.sum(), metrics
 