Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-12 20:26:45 +08:00)

Commit 132225a018 ("updates"), parent 925e11439b
@@ -61,112 +61,101 @@ def generate_grpo(
     temperature: float = 0.8,
     batch_size: int = 1
 ):
-    if len(prompts.shape) == 1:
-        prompts = prompts[None, :]
-    if prompts.shape[1] == 0:
-        return None
-
-    total_samples = prompts.shape[0] * group_size
-    expanded_prompts = mx.repeat(prompts, group_size, axis=0)
-    end_sequence = mx.array(tokenizer.encode(end_token))
-    results = []
-    mx.eval(expanded_prompts)
+    # Store original training state
+    was_training = model.training
+
+    # Set model to eval mode for generation
+    model.eval()
 
     try:
+        if len(prompts.shape) == 1:
+            prompts = prompts[None, :]
+        if prompts.shape[1] == 0:
+            return None
+
+        total_samples = prompts.shape[0] * group_size
+        expanded_prompts = mx.repeat(prompts, group_size, axis=0)
+        end_sequence = mx.array(tokenizer.encode(end_token))
+        results = []
+        mx.eval(expanded_prompts)
+
         # Process in batches
         for batch_start in range(0, total_samples, batch_size):
             batch_end = min(batch_start + batch_size, total_samples)
 
             if is_training:
-                # Training mode with batched processing
+                # Training-specific generation logic
                 batch_inputs = expanded_prompts[batch_start:batch_end]
+                batch_tokens = [[] for _ in range(batch_end - batch_start)]
                 prompt_caches = [cache.make_prompt_cache(model) for _ in range(batch_end - batch_start)]
 
-                # Initial forward pass for all prompts in batch
-                batch_logits = []
+                # Initial forward pass
                 for i, prompt in enumerate(batch_inputs):
                     logits = model(prompt[None], cache=prompt_caches[i])[:, -1]
-                    batch_logits.append(logits)
-                mx.eval(batch_logits, prompt_caches)
-
-                # Track tokens for each sequence in the batch
-                batch_tokens = [[] for _ in range(batch_end - batch_start)]
-
-                # Initial token generation for all sequences in batch
-                for i in range(len(batch_logits)):
-                    logits_temp = batch_logits[i] / temperature
+                    logits_temp = logits / temperature
                     next_token = mx.random.categorical(logits_temp)
                     token = next_token.item()
-                    mx.eval(logits_temp, next_token, token)
                     batch_tokens[i].append(token)
+                    del logits, logits_temp, next_token
 
-                    # Check if this token already completes the sequence
-                    if token == tokenizer.eos_token_id:
-                        continue
-                    else:
-                        # Set up for next token
-                        current_input = mx.array([token])
-                        batch_logits[i] = model(current_input[None], cache=prompt_caches[i])[:, -1]
-
-                mx.eval(batch_logits)
-                active_indices = [i for i, tokens in enumerate(batch_tokens) if tokens[-1] != tokenizer.eos_token_id and len(tokens) < max_tokens]
-
-                # Generate tokens until all sequences are complete
-                while active_indices and max(len(tokens) for tokens in batch_tokens) < max_tokens:
+                mx.eval([tokens[-1] for tokens in batch_tokens])
+                mx.metal.clear_cache()
+
+                active_indices = [i for i in range(len(batch_tokens)) if batch_tokens[i][-1] != tokenizer.eos_token_id]
+
+                # Generate remaining tokens
+                for _ in range(max_tokens - 1):
+                    if not active_indices:
+                        break
+
                     next_active = []
                     for idx in active_indices:
-                        logits_temp = batch_logits[idx] / temperature
+                        current_input = mx.array([batch_tokens[idx][-1]])
+                        logits = model(current_input[None], cache=prompt_caches[idx])[:, -1]
+                        logits_temp = logits / temperature
                         next_token = mx.random.categorical(logits_temp)
                         token = next_token.item()
-                        mx.eval(logits_temp, next_token, token)
                         batch_tokens[idx].append(token)
 
-                        # Check for end sequence
+                        # Check for end conditions
+                        is_end = False
                         if len(batch_tokens[idx]) >= len(end_sequence):
                             test_sequence = batch_tokens[idx][-len(end_sequence):]
-                            is_end = mx.array_equal(
-                                mx.array(test_sequence),
-                                end_sequence
-                            )
-                        else:
-                            is_end = False
+                            is_end = mx.array_equal(mx.array(test_sequence), end_sequence)
 
-                        if is_end or token == tokenizer.eos_token_id or len(batch_tokens[idx]) >= max_tokens:
-                            # This sequence is done
-                            pass
-                        else:
-                            # Continue with this sequence
+                        if not (is_end or token == tokenizer.eos_token_id):
                             next_active.append(idx)
-                            current_input = mx.array([token])
-                            batch_logits[idx] = model(current_input[None], cache=prompt_caches[idx])[:, -1]
+                        del logits, logits_temp, next_token, current_input
 
-                    mx.eval([batch_logits[idx] for idx in next_active])
+                    mx.eval([tokens[-1] for tokens in batch_tokens])
+                    mx.metal.clear_cache()
                     active_indices = next_active
 
-                # Clear caches after processing this batch
+                # Clean up caches
                 for pc in prompt_caches:
                     del pc
 
-                # Add batch results to overall results
+                # Process results
                 for tokens in batch_tokens:
                     if tokens:
-                        # Filter out any special tokens that might appear after the end token
-                        if len(tokens) >= len(end_sequence):
-                            for i in range(len(tokens) - len(end_sequence) + 1):
-                                if mx.array_equal(
-                                    mx.array(tokens[i:i+len(end_sequence)]),
-                                    end_sequence
-                                ):
-                                    tokens = tokens[:i+len(end_sequence)]
-                                    break
+                        # Truncate at end token if present
+                        for i in range(len(tokens) - len(end_sequence) + 1):
+                            if mx.array_equal(
+                                mx.array(tokens[i:i+len(end_sequence)]),
+                                end_sequence
+                            ):
+                                tokens = tokens[:i+len(end_sequence)]
+                                break
 
-                        # Filter out EOS token if it's the last token
                         if tokens and tokens[-1] == tokenizer.eos_token_id:
                             tokens = tokens[:-1]
 
-                        # Only add non-empty token lists
                         if tokens:
                             results.append(mx.array(tokens))
+
+                del batch_inputs, batch_tokens, prompt_caches
+                mx.metal.clear_cache()
             else:
                 # Non-training mode with batched processing
                 for idx in range(batch_start, batch_end):
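
Aside (not part of the commit): the end-of-sequence handling in the hunk above reduces to a suffix match against the encoded end token plus truncation at its first occurrence. Below is a minimal sketch of that logic in plain Python, with made-up token ids standing in for tokenizer output; it is an illustration, not the repo's code.

    def ends_with(tokens, end_sequence):
        # True once the last len(end_sequence) tokens equal the stop sequence
        return len(tokens) >= len(end_sequence) and tokens[-len(end_sequence):] == end_sequence

    def truncate_at_end(tokens, end_sequence):
        # Keep everything up to and including the first occurrence of the stop sequence
        for i in range(len(tokens) - len(end_sequence) + 1):
            if tokens[i:i + len(end_sequence)] == end_sequence:
                return tokens[:i + len(end_sequence)]
        return tokens

    end_sequence = [27, 14]                                 # hypothetical ids for the end marker
    print(ends_with([5, 9, 27, 14], end_sequence))          # True  -> stop sampling this sequence
    print(truncate_at_end([5, 27, 14, 3], end_sequence))    # [5, 27, 14]
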
@@ -196,12 +185,17 @@ def generate_grpo(
                     results.append(mx.array(current_tokens))
 
             mx.metal.clear_cache()
 
         mx.eval(results)
         return results
 
     except Exception as e:
         print(f"Generation error: {str(e)}")
         return None
 
+    finally:
+        # Don't restore training mode - let the caller handle it
+        pass
+
 
 def get_per_token_logps(model: nn.Module, inputs, lengths):
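
Aside (not part of the commit): the change also settles on a convention for train/eval mode around generation: record model.training, switch to eval() for sampling, and leave restoring train mode to the caller (the finally block is deliberately a no-op). A small sketch of that convention follows, assuming only the standard mlx.nn.Module API (train(), eval(), training); DummyModel, generate_in_eval_mode, and the placeholder return value are made up for illustration.

    import mlx.nn as nn

    class DummyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4)    # stand-in layer so the module is non-trivial

    def generate_in_eval_mode(model):
        was_training = model.training      # recorded, but intentionally not restored here
        model.eval()                       # generation always runs in eval mode
        try:
            return "generated tokens"      # placeholder for the sampling loop
        finally:
            pass                           # caller is responsible for calling model.train()

    model = DummyModel()
    model.train()
    out = generate_in_eval_mode(model)
    model.train()                          # caller restores training mode afterwards
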