mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
generation should be fixed now
This commit is contained in:
parent
46d6146102
commit
0bc2a881ad
@ -57,228 +57,6 @@ class GRPOTrainingArgs(TrainingArgs):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def generate_step(
|
|
||||||
prompt: mx.array,
|
|
||||||
model: nn.Module,
|
|
||||||
*,
|
|
||||||
max_tokens: int = 256,
|
|
||||||
sampler: Optional[Callable[mx.array, mx.array]] = None,
|
|
||||||
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
|
||||||
max_kv_size: Optional[int] = None,
|
|
||||||
prompt_cache: Optional[Any] = None,
|
|
||||||
prefill_step_size: int = 512,
|
|
||||||
prompt_progress_callback: Optional[Callable[int, int]] = None,
|
|
||||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
|
||||||
"""
|
|
||||||
A generator producing token ids based on the given prompt from the model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt (mx.array): The input prompt.
|
|
||||||
model (nn.Module): The model to use for generation.
|
|
||||||
max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite
|
|
||||||
generator. Default: ``256``.
|
|
||||||
sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
|
|
||||||
token from a vector of log probabilities. Default: ``None``.
|
|
||||||
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
|
||||||
A list of functions that take tokens and logits and return the processed
|
|
||||||
logits. Default: ``None``.
|
|
||||||
max_kv_size (int, optional): Maximum size of the key-value cache. Old
|
|
||||||
entries (except the first 4 tokens) will be overwritten.
|
|
||||||
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
|
|
||||||
provided, the cache will be updated in place.
|
|
||||||
prefill_step_size (int): Step size for processing the prompt.
|
|
||||||
kv_bits (int, optional): Number of bits to use for KV cache quantization.
|
|
||||||
None implies no cache quantization. Default: ``None``.
|
|
||||||
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
|
|
||||||
quantized_kv_start (int): Step to begin using a quantized KV cache.
|
|
||||||
when ``kv_bits`` is non-None. Default: ``0``.
|
|
||||||
prompt_prorgress_callback (Callable[int, int]): A call-back which takes the
|
|
||||||
prompt tokens processed so far and the total number of prompt tokens.
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
|
|
||||||
"""
|
|
||||||
|
|
||||||
y = prompt
|
|
||||||
tokens = None
|
|
||||||
|
|
||||||
# Create the KV cache for generation
|
|
||||||
if prompt_cache is None:
|
|
||||||
prompt_cache = cache.make_prompt_cache(
|
|
||||||
model,
|
|
||||||
max_kv_size=max_kv_size,
|
|
||||||
)
|
|
||||||
elif len(prompt_cache) != len(model.layers):
|
|
||||||
raise ValueError("Wrong number of layers in the prompt cache.")
|
|
||||||
|
|
||||||
prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
|
|
||||||
|
|
||||||
sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
|
|
||||||
|
|
||||||
def _step(y):
|
|
||||||
with mx.stream(generation_stream):
|
|
||||||
logits = model(y[None], cache=prompt_cache)
|
|
||||||
logits = logits[:, -1, :]
|
|
||||||
|
|
||||||
if logits_processors:
|
|
||||||
nonlocal tokens
|
|
||||||
tokens = mx.concat([tokens, y]) if tokens is not None else y
|
|
||||||
|
|
||||||
for processor in logits_processors:
|
|
||||||
logits = processor(tokens, logits)
|
|
||||||
|
|
||||||
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
|
||||||
y = sampler(logprobs)
|
|
||||||
return mx.stop_gradient(y), mx.stop_gradient(logprobs.squeeze(0))
|
|
||||||
|
|
||||||
with mx.stream(generation_stream):
|
|
||||||
total_prompt_tokens = y.size
|
|
||||||
prompt_processed_tokens = 0
|
|
||||||
while y.size > prefill_step_size:
|
|
||||||
model(y[:prefill_step_size][None], cache=prompt_cache)
|
|
||||||
mx.eval([c.state for c in prompt_cache])
|
|
||||||
prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
|
|
||||||
prompt_processed_tokens += prefill_step_size
|
|
||||||
y = y[prefill_step_size:]
|
|
||||||
mx.metal.clear_cache()
|
|
||||||
|
|
||||||
y, logprobs = _step(y)
|
|
||||||
|
|
||||||
mx.eval(y, logprobs)
|
|
||||||
n = 0
|
|
||||||
while True:
|
|
||||||
if n != max_tokens:
|
|
||||||
next_y, next_logprobs = _step(y)
|
|
||||||
mx.eval(next_y, next_logprobs)
|
|
||||||
if n == 0:
|
|
||||||
mx.eval(y)
|
|
||||||
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
|
|
||||||
if n == max_tokens:
|
|
||||||
break
|
|
||||||
yield y.item(), logprobs
|
|
||||||
if n % 256 == 0:
|
|
||||||
mx.metal.clear_cache()
|
|
||||||
y, logprobs = next_y, next_logprobs
|
|
||||||
n += 1
|
|
||||||
|
|
||||||
|
|
||||||
def generate_grpo(
|
|
||||||
model: nn.Module,
|
|
||||||
prompts,
|
|
||||||
max_tokens,
|
|
||||||
tokenizer,
|
|
||||||
group_size,
|
|
||||||
end_token: str = "</answer>",
|
|
||||||
temperature: float = 0.8,
|
|
||||||
batch_size: int = 1,
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
import time
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
if len(prompts.shape) == 1:
|
|
||||||
prompts = prompts[None, :]
|
|
||||||
if prompts.shape[1] == 0:
|
|
||||||
return None
|
|
||||||
|
|
||||||
total_samples = prompts.shape[0] * group_size
|
|
||||||
expanded_prompts = mx.repeat(prompts, group_size, axis=0)
|
|
||||||
end_sequence = mx.array(tokenizer.encode(end_token))
|
|
||||||
results = []
|
|
||||||
mx.eval(expanded_prompts, results)
|
|
||||||
|
|
||||||
print(f"Setup time: {time.time() - start_time:.2f}s")
|
|
||||||
print(f"Generating {total_samples} samples with max_tokens={max_tokens}")
|
|
||||||
|
|
||||||
total_tokens_generated = 0
|
|
||||||
generation_start_time = time.time()
|
|
||||||
|
|
||||||
# Process in batches
|
|
||||||
for batch_start in range(0, total_samples, batch_size):
|
|
||||||
batch_end = min(batch_start + batch_size, total_samples)
|
|
||||||
batch_time = time.time()
|
|
||||||
print(
|
|
||||||
f"Starting batch {batch_start//batch_size + 1}/{(total_samples + batch_size - 1)//batch_size}: samples {batch_start}-{batch_end-1}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Custom sampler function that handles temperature
|
|
||||||
def temp_sampler(logits):
|
|
||||||
return mx.random.categorical(logits / temperature)
|
|
||||||
|
|
||||||
# Batched processing
|
|
||||||
for idx in range(batch_start, batch_end):
|
|
||||||
sample_start_time = time.time()
|
|
||||||
current_tokens = []
|
|
||||||
prompt_cache = cache.make_prompt_cache(model)
|
|
||||||
|
|
||||||
# The generate_step function yields one token at a time
|
|
||||||
# We'll collect tokens until we hit max_tokens or a stopping condition
|
|
||||||
for i, (token, _) in enumerate(
|
|
||||||
generate_step(
|
|
||||||
expanded_prompts[idx],
|
|
||||||
model,
|
|
||||||
max_tokens=max_tokens, # This is the maximum number of steps
|
|
||||||
sampler=temp_sampler,
|
|
||||||
prompt_cache=prompt_cache,
|
|
||||||
)
|
|
||||||
):
|
|
||||||
# Check for EOS token
|
|
||||||
if token == tokenizer.eos_token_id:
|
|
||||||
break
|
|
||||||
|
|
||||||
current_tokens.append(token)
|
|
||||||
|
|
||||||
print(token)
|
|
||||||
|
|
||||||
# Check for end token
|
|
||||||
if len(current_tokens) >= len(end_sequence) and mx.array_equal(
|
|
||||||
mx.array(current_tokens[-len(end_sequence) :]), end_sequence
|
|
||||||
):
|
|
||||||
break
|
|
||||||
|
|
||||||
# Check if we've reached the maximum number of tokens
|
|
||||||
if i >= max_tokens - 1:
|
|
||||||
break
|
|
||||||
|
|
||||||
mx.metal.clear_cache()
|
|
||||||
mx.eval(current_tokens)
|
|
||||||
|
|
||||||
if current_tokens:
|
|
||||||
results.append(mx.array(current_tokens))
|
|
||||||
total_tokens_generated += len(current_tokens)
|
|
||||||
|
|
||||||
sample_time = time.time() - sample_start_time
|
|
||||||
tokens_per_second = (
|
|
||||||
len(current_tokens) / sample_time if sample_time > 0 else 0
|
|
||||||
)
|
|
||||||
print(
|
|
||||||
f" Sample {idx}: Generated {len(current_tokens)} tokens in {sample_time:.2f}s ({tokens_per_second:.2f} tokens/sec)"
|
|
||||||
)
|
|
||||||
|
|
||||||
batch_time = time.time() - batch_time
|
|
||||||
print(f"Batch completed in {batch_time:.2f}s")
|
|
||||||
mx.metal.clear_cache()
|
|
||||||
|
|
||||||
generation_time = time.time() - generation_start_time
|
|
||||||
avg_tokens_per_second = (
|
|
||||||
total_tokens_generated / generation_time if generation_time > 0 else 0
|
|
||||||
)
|
|
||||||
|
|
||||||
print(
|
|
||||||
f"Generation complete: {total_tokens_generated} tokens in {generation_time:.2f}s"
|
|
||||||
)
|
|
||||||
print(f"Average generation speed: {avg_tokens_per_second:.2f} tokens/sec")
|
|
||||||
|
|
||||||
results = [mx.stop_gradient(r) for r in results]
|
|
||||||
mx.eval(results)
|
|
||||||
return results
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Generation error: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_per_token_logps(model: nn.Module, inputs, lengths):
|
def get_per_token_logps(model: nn.Module, inputs, lengths):
|
||||||
logits = model(inputs).astype(mx.float16)
|
logits = model(inputs).astype(mx.float16)
|
||||||
logits = logits[:, :-1, :]
|
logits = logits[:, :-1, :]
|
||||||
@ -297,74 +75,123 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):
|
|||||||
return per_token_logps
|
return per_token_logps
|
||||||
|
|
||||||
|
|
||||||
def generate_without_gradients(
|
def generate_step(
|
||||||
|
prompt: mx.array,
|
||||||
|
model: nn.Module,
|
||||||
|
max_tokens: int = 256,
|
||||||
|
sampler: Optional[Callable] = None,
|
||||||
|
logits_processors: Optional[List[Callable]] = None,
|
||||||
|
max_kv_size: Optional[int] = None,
|
||||||
|
prompt_cache: Optional[Any] = None,
|
||||||
|
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||||
|
tokens = None
|
||||||
|
y = prompt
|
||||||
|
if prompt_cache is None:
|
||||||
|
prompt_cache = cache.make_prompt_cache(model, max_kv_size=max_kv_size)
|
||||||
|
def _step(y):
|
||||||
|
with mx.stream(generation_stream):
|
||||||
|
logits = model(y[None], cache=prompt_cache)
|
||||||
|
logits = logits[:, -1, :]
|
||||||
|
if logits_processors:
|
||||||
|
nonlocal tokens
|
||||||
|
tokens = mx.concat([tokens, y]) if tokens is not None else y
|
||||||
|
for processor in logits_processors:
|
||||||
|
logits = processor(tokens, logits)
|
||||||
|
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
||||||
|
next_token = sampler(logprobs)
|
||||||
|
return mx.stop_gradient(next_token), mx.stop_gradient(logprobs.squeeze(0))
|
||||||
|
try:
|
||||||
|
with mx.stream(generation_stream):
|
||||||
|
y, logprobs = _step(y)
|
||||||
|
mx.eval(y, logprobs)
|
||||||
|
for n in range(max_tokens):
|
||||||
|
yield y.item(), logprobs
|
||||||
|
next_y, next_logprobs = _step(y)
|
||||||
|
mx.eval(next_y, next_logprobs)
|
||||||
|
y, logprobs = next_y, next_logprobs
|
||||||
|
if (n + 1) % 32 == 0:
|
||||||
|
mx.metal.clear_cache()
|
||||||
|
finally:
|
||||||
|
mx.metal.clear_cache()
|
||||||
|
|
||||||
|
|
||||||
|
def generate_grpo(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
prompt_tokens,
|
prompt_tokens,
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
group_size: int,
|
group_size: int,
|
||||||
|
end_token: str = "</answer>",
|
||||||
temperature: float = 0.8,
|
temperature: float = 0.8,
|
||||||
batch_size: int = 1
|
batch_size: int = 1,
|
||||||
):
|
):
|
||||||
"""Generate completions without tracking gradients"""
|
try:
|
||||||
|
end_sequence = mx.array(tokenizer.encode(end_token))
|
||||||
# Store original state
|
|
||||||
was_training = model.training
|
|
||||||
|
|
||||||
# Force eval mode
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
# Prepare prompts
|
|
||||||
total_samples = len(prompt_tokens)
|
total_samples = len(prompt_tokens)
|
||||||
all_completions = []
|
all_completions = []
|
||||||
all_completion_texts = []
|
all_completion_texts = []
|
||||||
batch_indices = []
|
batch_indices = []
|
||||||
|
|
||||||
# Process in smaller batches
|
def temp_sampler(logits):
|
||||||
|
return mx.random.categorical(logits / temperature)
|
||||||
|
|
||||||
for i in range(0, total_samples, batch_size):
|
for i in range(0, total_samples, batch_size):
|
||||||
current_batch_size = min(batch_size, total_samples - i)
|
current_batch_size = min(batch_size, total_samples - i)
|
||||||
batch_prompts = prompt_tokens[i : i + current_batch_size]
|
batch_prompts = prompt_tokens[i : i + current_batch_size]
|
||||||
|
|
||||||
# Pad sequences to the same length
|
|
||||||
max_prompt_len = max(len(p) for p in batch_prompts)
|
max_prompt_len = max(len(p) for p in batch_prompts)
|
||||||
padded_prompts = []
|
padded_prompts = []
|
||||||
|
|
||||||
for prompt in batch_prompts:
|
for prompt in batch_prompts:
|
||||||
padding = [tokenizer.pad_token_id] * (max_prompt_len - len(prompt))
|
padding = [tokenizer.pad_token_id] * (max_prompt_len - len(prompt))
|
||||||
padded_prompts.append(prompt + padding)
|
padded_prompts.append(prompt + padding)
|
||||||
|
|
||||||
# Convert to tensor and explicitly stop gradient
|
|
||||||
prompt_tensor = mx.stop_gradient(mx.array(padded_prompts))
|
prompt_tensor = mx.stop_gradient(mx.array(padded_prompts))
|
||||||
|
|
||||||
try:
|
if len(prompt_tensor.shape) == 1:
|
||||||
completions = generate_grpo(
|
prompt_tensor = prompt_tensor[None, :]
|
||||||
|
if prompt_tensor.shape[1] == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
expanded_prompts = mx.repeat(prompt_tensor, group_size, axis=0)
|
||||||
|
batch_results = []
|
||||||
|
|
||||||
|
total_prompt_samples = expanded_prompts.shape[0]
|
||||||
|
for prompt_idx in range(total_prompt_samples):
|
||||||
|
current_tokens = []
|
||||||
|
prompt_cache = cache.make_prompt_cache(model)
|
||||||
|
|
||||||
|
for token, _ in generate_step(
|
||||||
|
expanded_prompts[prompt_idx],
|
||||||
model,
|
model,
|
||||||
prompt_tensor,
|
max_tokens=max_tokens,
|
||||||
max_tokens,
|
sampler=temp_sampler,
|
||||||
tokenizer,
|
prompt_cache=prompt_cache,
|
||||||
group_size,
|
):
|
||||||
temperature=temperature,
|
if token == tokenizer.eos_token_id:
|
||||||
batch_size=current_batch_size,
|
break
|
||||||
)
|
|
||||||
|
|
||||||
if completions is not None:
|
current_tokens.append(token)
|
||||||
for j, completion_ids in enumerate(completions):
|
|
||||||
|
if len(current_tokens) >= len(end_sequence) and mx.array_equal(
|
||||||
|
mx.array(current_tokens[-len(end_sequence):]), end_sequence
|
||||||
|
):
|
||||||
|
break
|
||||||
|
|
||||||
|
if current_tokens:
|
||||||
|
batch_results.append(mx.array(current_tokens))
|
||||||
|
|
||||||
|
if batch_results:
|
||||||
|
for j, completion_ids in enumerate(batch_results):
|
||||||
prompt_idx = i + (j // group_size)
|
prompt_idx = i + (j // group_size)
|
||||||
|
|
||||||
if prompt_idx < total_samples:
|
if prompt_idx < total_samples:
|
||||||
batch_indices.append(prompt_idx)
|
batch_indices.append(prompt_idx)
|
||||||
completion_text = tokenizer.decode(completion_ids.tolist())
|
completion_text = tokenizer.decode(completion_ids.tolist())
|
||||||
all_completions.append(completion_ids)
|
all_completions.append(mx.stop_gradient(completion_ids))
|
||||||
all_completion_texts.append(completion_text)
|
all_completion_texts.append(completion_text)
|
||||||
mx.eval(completion_ids)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Generation error: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Restore original state
|
mx.metal.clear_cache()
|
||||||
if was_training:
|
|
||||||
model.train()
|
|
||||||
|
|
||||||
|
finally:
|
||||||
mx.metal.clear_cache()
|
mx.metal.clear_cache()
|
||||||
|
|
||||||
return all_completions, all_completion_texts, batch_indices
|
return all_completions, all_completion_texts, batch_indices
|
||||||
@ -375,6 +202,9 @@ def grpo_loss(
|
|||||||
ref_model,
|
ref_model,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
batch,
|
batch,
|
||||||
|
completions=None,
|
||||||
|
completion_texts=None,
|
||||||
|
batch_indices=None,
|
||||||
reward_funcs: Optional[List[RewardFunctions]] = None,
|
reward_funcs: Optional[List[RewardFunctions]] = None,
|
||||||
beta: float = 0.1,
|
beta: float = 0.1,
|
||||||
group_size: int = 4,
|
group_size: int = 4,
|
||||||
@ -387,8 +217,12 @@ def grpo_loss(
|
|||||||
):
|
):
|
||||||
prompt_tokens, _, prompt_text, answer_text = batch
|
prompt_tokens, _, prompt_text, answer_text = batch
|
||||||
|
|
||||||
# Generate completions without tracking gradients
|
if completions is not None and completion_texts is not None and batch_indices is not None:
|
||||||
all_completions, all_completion_texts, batch_indices = generate_without_gradients(
|
all_completions = completions
|
||||||
|
all_completion_texts = completion_texts
|
||||||
|
batch_indices = batch_indices
|
||||||
|
else:
|
||||||
|
all_completions, all_completion_texts, batch_indices = generate_grpo(
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
prompt_tokens=prompt_tokens,
|
prompt_tokens=prompt_tokens,
|
||||||
@ -398,24 +232,20 @@ def grpo_loss(
|
|||||||
batch_size=batch_size
|
batch_size=batch_size
|
||||||
)
|
)
|
||||||
|
|
||||||
# If we didn't generate any completions, return early
|
|
||||||
if not all_completions:
|
if not all_completions:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"No completions were generated. Please check your model and inputs."
|
"No completions were generated. Please check your model and inputs."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create expanded prompts and answers based on actual generated completions
|
|
||||||
expanded_answers = []
|
expanded_answers = []
|
||||||
expanded_prompts = []
|
expanded_prompts = []
|
||||||
|
|
||||||
# Group completions by their original prompt
|
|
||||||
unique_prompt_indices = sorted(set(batch_indices))
|
unique_prompt_indices = sorted(set(batch_indices))
|
||||||
grouped_completions = {idx: [] for idx in unique_prompt_indices}
|
grouped_completions = {idx: [] for idx in unique_prompt_indices}
|
||||||
|
|
||||||
for i, completion_idx in enumerate(batch_indices):
|
for i, completion_idx in enumerate(batch_indices):
|
||||||
grouped_completions[completion_idx].append(i)
|
grouped_completions[completion_idx].append(i)
|
||||||
|
|
||||||
# Rebuild completions in the correct order
|
|
||||||
ordered_completions = []
|
ordered_completions = []
|
||||||
ordered_completion_texts = []
|
ordered_completion_texts = []
|
||||||
ordered_batch_indices = []
|
ordered_batch_indices = []
|
||||||
@ -426,8 +256,6 @@ def grpo_loss(
|
|||||||
ordered_completions.append(all_completions[idx])
|
ordered_completions.append(all_completions[idx])
|
||||||
ordered_completion_texts.append(all_completion_texts[idx])
|
ordered_completion_texts.append(all_completion_texts[idx])
|
||||||
ordered_batch_indices.append(prompt_idx)
|
ordered_batch_indices.append(prompt_idx)
|
||||||
|
|
||||||
# Add corresponding prompt and answer
|
|
||||||
expanded_prompts.append(prompt_text[prompt_idx])
|
expanded_prompts.append(prompt_text[prompt_idx])
|
||||||
expanded_answers.append(answer_text[prompt_idx])
|
expanded_answers.append(answer_text[prompt_idx])
|
||||||
|
|
||||||
@ -435,14 +263,11 @@ def grpo_loss(
|
|||||||
all_completion_texts = ordered_completion_texts
|
all_completion_texts = ordered_completion_texts
|
||||||
batch_indices = ordered_batch_indices
|
batch_indices = ordered_batch_indices
|
||||||
|
|
||||||
# Create new input tensors for the model to compute logits with gradient tracking
|
|
||||||
max_length = max(ids.shape[0] for ids in all_completions)
|
max_length = max(ids.shape[0] for ids in all_completions)
|
||||||
padded_completions = []
|
padded_completions = []
|
||||||
attention_masks = []
|
attention_masks = []
|
||||||
|
|
||||||
for completion_ids in all_completions:
|
for completion_ids in all_completions:
|
||||||
# Convert the pre-generated completion to a regular tensor (not stop_gradient)
|
|
||||||
# This allows gradients to flow during the loss computation phase
|
|
||||||
completion_tensor = mx.array(completion_ids.tolist())
|
completion_tensor = mx.array(completion_ids.tolist())
|
||||||
|
|
||||||
padding_length = max_length - completion_tensor.shape[0]
|
padding_length = max_length - completion_tensor.shape[0]
|
||||||
@ -458,12 +283,10 @@ def grpo_loss(
|
|||||||
padded_completions.append(padded_ids)
|
padded_completions.append(padded_ids)
|
||||||
attention_masks.append(mask)
|
attention_masks.append(mask)
|
||||||
|
|
||||||
# Rest of the function remains the same
|
|
||||||
inputs = mx.stack(padded_completions)
|
inputs = mx.stack(padded_completions)
|
||||||
attention_mask = mx.stack(attention_masks)
|
attention_mask = mx.stack(attention_masks)
|
||||||
lengths = attention_mask.sum(axis=1)
|
lengths = attention_mask.sum(axis=1)
|
||||||
|
|
||||||
# Current policy probabilities
|
|
||||||
token_log_probs = get_per_token_logps(model, inputs, lengths)
|
token_log_probs = get_per_token_logps(model, inputs, lengths)
|
||||||
mx.eval(token_log_probs)
|
mx.eval(token_log_probs)
|
||||||
|
|
||||||
@ -487,10 +310,8 @@ def grpo_loss(
|
|||||||
token_log_probs = mx.stack(padded_log_probs)
|
token_log_probs = mx.stack(padded_log_probs)
|
||||||
ref_token_log_probs = mx.stack(padded_ref_log_probs)
|
ref_token_log_probs = mx.stack(padded_ref_log_probs)
|
||||||
|
|
||||||
# Create array to store rewards from each function
|
|
||||||
all_func_rewards = []
|
all_func_rewards = []
|
||||||
|
|
||||||
# Collect rewards from each function separately
|
|
||||||
for reward_func in reward_funcs:
|
for reward_func in reward_funcs:
|
||||||
func_rewards = mx.array(
|
func_rewards = mx.array(
|
||||||
reward_func(
|
reward_func(
|
||||||
@ -501,10 +322,8 @@ def grpo_loss(
|
|||||||
)
|
)
|
||||||
all_func_rewards.append(func_rewards)
|
all_func_rewards.append(func_rewards)
|
||||||
|
|
||||||
# Stack rewards to shape (num_samples, num_funcs)
|
|
||||||
rewards = mx.stack(all_func_rewards, axis=1)
|
rewards = mx.stack(all_func_rewards, axis=1)
|
||||||
|
|
||||||
# Apply weights and sum
|
|
||||||
if reward_weights is not None:
|
if reward_weights is not None:
|
||||||
if len(reward_weights) != len(reward_funcs):
|
if len(reward_weights) != len(reward_funcs):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -517,24 +336,19 @@ def grpo_loss(
|
|||||||
|
|
||||||
rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
|
rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
|
||||||
|
|
||||||
# Get number of unique prompts
|
|
||||||
num_unique_prompts = len(unique_prompt_indices)
|
num_unique_prompts = len(unique_prompt_indices)
|
||||||
|
|
||||||
# Reshape rewards based on actual groups
|
|
||||||
rewards_by_prompt = [[] for _ in range(num_unique_prompts)]
|
rewards_by_prompt = [[] for _ in range(num_unique_prompts)]
|
||||||
for i, prompt_idx in enumerate(batch_indices):
|
for i, prompt_idx in enumerate(batch_indices):
|
||||||
prompt_position = unique_prompt_indices.index(prompt_idx)
|
prompt_position = unique_prompt_indices.index(prompt_idx)
|
||||||
rewards_by_prompt[prompt_position].append(rewards[i])
|
rewards_by_prompt[prompt_position].append(rewards[i])
|
||||||
|
|
||||||
# Calculate advantages within each group
|
|
||||||
advantages = mx.zeros_like(rewards)
|
advantages = mx.zeros_like(rewards)
|
||||||
for i, prompt_rewards in enumerate(rewards_by_prompt):
|
for i, prompt_rewards in enumerate(rewards_by_prompt):
|
||||||
if len(prompt_rewards) > 1: # Only normalize if we have multiple samples
|
if len(prompt_rewards) > 1:
|
||||||
prompt_rewards = mx.array(prompt_rewards)
|
prompt_rewards = mx.array(prompt_rewards)
|
||||||
mean_reward = mx.mean(prompt_rewards)
|
mean_reward = mx.mean(prompt_rewards)
|
||||||
std_reward = mx.std(prompt_rewards)
|
std_reward = mx.std(prompt_rewards)
|
||||||
|
|
||||||
# Find indices for this prompt
|
|
||||||
indices = [
|
indices = [
|
||||||
j
|
j
|
||||||
for j, idx in enumerate(batch_indices)
|
for j, idx in enumerate(batch_indices)
|
||||||
@ -545,7 +359,6 @@ def grpo_loss(
|
|||||||
std_reward + epsilon
|
std_reward + epsilon
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# If only one sample, advantage is 0
|
|
||||||
idx = batch_indices.index(unique_prompt_indices[i])
|
idx = batch_indices.index(unique_prompt_indices[i])
|
||||||
advantages[idx] = 0.0
|
advantages[idx] = 0.0
|
||||||
|
|
||||||
@ -746,6 +559,7 @@ def evaluate_grpo(
|
|||||||
ref_model=ref_model,
|
ref_model=ref_model,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
|
is_validation=True
|
||||||
)
|
)
|
||||||
|
|
||||||
all_losses += losses * toks
|
all_losses += losses * toks
|
||||||
@ -803,21 +617,37 @@ def train_grpo(
|
|||||||
state = [model.state, optimizer.state]
|
state = [model.state, optimizer.state]
|
||||||
|
|
||||||
def step(batch):
|
def step(batch):
|
||||||
|
# Extract prompt tokens from the batch
|
||||||
|
prompt_tokens, targets, prompt_lens, target_lens = batch
|
||||||
|
|
||||||
|
# First, generate completions without gradient tracking
|
||||||
|
# The model will be frozen during this call
|
||||||
|
all_completions, all_completion_texts, batch_indices = generate_grpo(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
max_tokens=args.max_completion_length,
|
||||||
|
group_size=args.group_size,
|
||||||
|
temperature=args.temperature
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now calculate loss and gradients with pre-generated completions
|
||||||
|
# We need to update loss_fn to accept these pre-generated completions
|
||||||
(loss, toks, metrics), grad = loss_value_and_grad(
|
(loss, toks, metrics), grad = loss_value_and_grad(
|
||||||
model,
|
model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
batch=batch,
|
batch=(prompt_tokens, targets, prompt_lens, target_lens),
|
||||||
|
completions=all_completions,
|
||||||
|
completion_texts=all_completion_texts,
|
||||||
|
batch_indices=batch_indices,
|
||||||
reward_funcs=reward_funcs,
|
reward_funcs=reward_funcs,
|
||||||
beta=args.beta,
|
beta=args.beta,
|
||||||
group_size=args.group_size,
|
group_size=args.group_size,
|
||||||
epsilon=args.epsilon,
|
epsilon=args.epsilon,
|
||||||
ref_model=ref_model,
|
ref_model=ref_model,
|
||||||
max_tokens=args.max_completion_length,
|
|
||||||
temperature=args.temperature,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
grad = average_gradients(grad)
|
grad = average_gradients(grad)
|
||||||
|
|
||||||
optimizer.update(model, grad)
|
optimizer.update(model, grad)
|
||||||
|
|
||||||
return loss, toks, metrics
|
return loss, toks, metrics
|
||||||
|
Loading…
Reference in New Issue
Block a user