Goekdeniz-Guelmez 2025-03-01 22:23:33 +01:00
parent 925e11439b
commit 132225a018


@@ -61,112 +61,101 @@ def generate_grpo(
     temperature: float = 0.8,
     batch_size: int = 1
 ):
-    if len(prompts.shape) == 1:
-        prompts = prompts[None, :]
-    if prompts.shape[1] == 0:
-        return None
-    total_samples = prompts.shape[0] * group_size
-    expanded_prompts = mx.repeat(prompts, group_size, axis=0)
-    end_sequence = mx.array(tokenizer.encode(end_token))
-    results = []
-    mx.eval(expanded_prompts)
+    # Store original training state
+    was_training = model.training
+    # Set model to eval mode for generation
+    model.eval()
     try:
+        if len(prompts.shape) == 1:
+            prompts = prompts[None, :]
+        if prompts.shape[1] == 0:
+            return None
+        total_samples = prompts.shape[0] * group_size
+        expanded_prompts = mx.repeat(prompts, group_size, axis=0)
+        end_sequence = mx.array(tokenizer.encode(end_token))
+        results = []
+        mx.eval(expanded_prompts)
         # Process in batches
         for batch_start in range(0, total_samples, batch_size):
             batch_end = min(batch_start + batch_size, total_samples)
             if is_training:
-                # Training mode with batched processing
+                # Training-specific generation logic
                 batch_inputs = expanded_prompts[batch_start:batch_end]
+                batch_tokens = [[] for _ in range(batch_end - batch_start)]
                 prompt_caches = [cache.make_prompt_cache(model) for _ in range(batch_end - batch_start)]
-                # Initial forward pass for all prompts in batch
-                batch_logits = []
+                # Initial forward pass
                 for i, prompt in enumerate(batch_inputs):
                     logits = model(prompt[None], cache=prompt_caches[i])[:, -1]
-                    batch_logits.append(logits)
-                mx.eval(batch_logits, prompt_caches)
-                # Track tokens for each sequence in the batch
-                batch_tokens = [[] for _ in range(batch_end - batch_start)]
-                # Initial token generation for all sequences in batch
-                for i in range(len(batch_logits)):
-                    logits_temp = batch_logits[i] / temperature
+                    logits_temp = logits / temperature
                     next_token = mx.random.categorical(logits_temp)
                     token = next_token.item()
-                    mx.eval(logits_temp, next_token, token)
                     batch_tokens[i].append(token)
-                    # Check if this token already completes the sequence
-                    if token == tokenizer.eos_token_id:
-                        continue
-                    else:
-                        # Set up for next token
-                        current_input = mx.array([token])
-                        batch_logits[i] = model(current_input[None], cache=prompt_caches[i])[:, -1]
-                mx.eval(batch_logits)
-                active_indices = [i for i, tokens in enumerate(batch_tokens) if tokens[-1] != tokenizer.eos_token_id and len(tokens) < max_tokens]
-                # Generate tokens until all sequences are complete
-                while active_indices and max(len(tokens) for tokens in batch_tokens) < max_tokens:
+                    del logits, logits_temp, next_token
+                mx.eval([tokens[-1] for tokens in batch_tokens])
+                mx.metal.clear_cache()
+                active_indices = [i for i in range(len(batch_tokens)) if batch_tokens[i][-1] != tokenizer.eos_token_id]
+                # Generate remaining tokens
+                for _ in range(max_tokens - 1):
+                    if not active_indices:
+                        break
                     next_active = []
                     for idx in active_indices:
-                        logits_temp = batch_logits[idx] / temperature
+                        current_input = mx.array([batch_tokens[idx][-1]])
+                        logits = model(current_input[None], cache=prompt_caches[idx])[:, -1]
+                        logits_temp = logits / temperature
                         next_token = mx.random.categorical(logits_temp)
                         token = next_token.item()
-                        mx.eval(logits_temp, next_token, token)
                        batch_tokens[idx].append(token)
-                        # Check for end sequence
+                        # Check for end conditions
+                        is_end = False
                         if len(batch_tokens[idx]) >= len(end_sequence):
                             test_sequence = batch_tokens[idx][-len(end_sequence):]
-                            is_end = mx.array_equal(
-                                mx.array(test_sequence),
-                                end_sequence
-                            )
-                        else:
-                            is_end = False
-                        if is_end or token == tokenizer.eos_token_id or len(batch_tokens[idx]) >= max_tokens:
-                            # This sequence is done
-                            pass
-                        else:
-                            # Continue with this sequence
+                            is_end = mx.array_equal(mx.array(test_sequence), end_sequence)
+                        if not (is_end or token == tokenizer.eos_token_id):
                             next_active.append(idx)
-                            current_input = mx.array([token])
-                            batch_logits[idx] = model(current_input[None], cache=prompt_caches[idx])[:, -1]
-                    mx.eval([batch_logits[idx] for idx in next_active])
+                        del logits, logits_temp, next_token, current_input
+                    mx.eval([tokens[-1] for tokens in batch_tokens])
+                    mx.metal.clear_cache()
                     active_indices = next_active
-                # Clear caches after processing this batch
+                # Clean up caches
                 for pc in prompt_caches:
                     del pc
-                # Add batch results to overall results
+                # Process results
                 for tokens in batch_tokens:
                     if tokens:
-                        # Filter out any special tokens that might appear after the end token
-                        if len(tokens) >= len(end_sequence):
-                            for i in range(len(tokens) - len(end_sequence) + 1):
-                                if mx.array_equal(
-                                    mx.array(tokens[i:i+len(end_sequence)]),
-                                    end_sequence
-                                ):
-                                    tokens = tokens[:i+len(end_sequence)]
-                                    break
-                        # Filter out EOS token if it's the last token
+                        # Truncate at end token if present
+                        for i in range(len(tokens) - len(end_sequence) + 1):
+                            if mx.array_equal(
+                                mx.array(tokens[i:i+len(end_sequence)]),
+                                end_sequence
+                            ):
+                                tokens = tokens[:i+len(end_sequence)]
+                                break
                         if tokens and tokens[-1] == tokenizer.eos_token_id:
                             tokens = tokens[:-1]
-                        # Only add non-empty token lists
                         if tokens:
                             results.append(mx.array(tokens))
+                del batch_inputs, batch_tokens, prompt_caches
+                mx.metal.clear_cache()
             else:
                 # Non-training mode with batched processing
                 for idx in range(batch_start, batch_end):
@@ -196,12 +185,17 @@ def generate_grpo(
                     results.append(mx.array(current_tokens))
             mx.metal.clear_cache()
         mx.eval(results)
         return results
     except Exception as e:
         print(f"Generation error: {str(e)}")
         return None
+    finally:
+        # Don't restore training mode - let the caller handle it
+        pass

 def get_per_token_logps(model: nn.Module, inputs, lengths):
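A note on the new train/eval bookkeeping: the commit records was_training, switches the model to eval mode before generation, and the added finally block deliberately does not switch it back, leaving that to the caller. The sketch below shows one way a caller could handle that hand-off; the wrapper name and the keyword arguments passed to generate_grpo are assumptions for illustration, not code from this repository.

import mlx.nn as nn


def sample_group_completions(model: nn.Module, tokenizer, prompts, **gen_kwargs):
    # Hypothetical caller-side sketch (not part of this commit): generate_grpo()
    # puts the model into eval mode and intentionally leaves it there, so the
    # training loop restores train mode itself after sampling.
    was_training = model.training  # remember the mode we started in

    # Keyword names below mirror parameters visible in the diff; the full
    # signature of generate_grpo() is only partially shown there.
    completions = generate_grpo(
        model=model,
        tokenizer=tokenizer,
        prompts=prompts,
        is_training=True,
        **gen_kwargs,
    )

    if was_training:
        model.train()  # back to training mode before the GRPO loss/backward pass
    return completions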
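The result-processing loop in the training branch truncates each sampled sequence at the first occurrence of the encoded end token and then drops a trailing EOS token. The helper below restates that post-processing on plain Python lists purely as a readable sketch; the function name is invented here and does not exist in the diff.

from typing import List, Optional


def truncate_at_end_sequence(
    tokens: List[int],
    end_sequence: List[int],
    eos_token_id: Optional[int] = None,
) -> List[int]:
    # Keep everything up to and including the first occurrence of end_sequence,
    # mirroring the mx.array_equal scan in the diff but on plain Python lists.
    n = len(end_sequence)
    for i in range(len(tokens) - n + 1):
        if tokens[i:i + n] == end_sequence:
            tokens = tokens[:i + n]
            break
    # Drop a trailing EOS token, as the commit does before collecting results.
    if eos_token_id is not None and tokens and tokens[-1] == eos_token_id:
        tokens = tokens[:-1]
    return tokens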