training mode working too: went from 2 toks/sec to 30 toks/sec with the raw 1.5B model

Goekdeniz-Guelmez 2025-02-21 22:42:15 +01:00
parent 6086137131
commit 710bc1490e


@@ -112,51 +112,90 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
     return scores
 
-def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size):
+def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False):
     if len(prompts.shape) == 1:
         prompts = prompts[None, :]
     if prompts.shape[1] == 0:
         return None
-    model.eval()
     batch_size = prompts.shape[0] * group_size
     expanded_prompts = mx.repeat(prompts, group_size, axis=0)
+    mx.eval(expanded_prompts)
     results = []
     tokens_generated = 0
     start_time = time.perf_counter()
-    try:
-        for idx in range(batch_size):
-            current_tokens = []
-            generator = generate_step(
-                expanded_prompts[idx],
-                model,
-                max_tokens=max_tokens,
-                sampler=lambda x: mx.argmax(x, axis=-1)
-            )
-            # Collect all tokens first
-            for tokens, _ in generator:
-                current_tokens.append(tokens)
-                tokens_generated += 1
-                if tokens == tokenizer.eos_token_id:
-                    break
-            # Convert to array after collection
-            results.append(mx.array(current_tokens))
-            mx.metal.clear_cache()
-
-        # Final evaluation of all results
-        mx.eval(results)
-        generation_time = time.perf_counter() - start_time
-        print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
-        return results
-    except Exception as e:
-        print(f"Generation error: {str(e)}")
+
+    for idx in range(batch_size):
+        current_prompt = expanded_prompts[idx:idx+1]
+        mx.eval(current_prompt)
+
+        current_tokens = []
+        try:
+            if is_training:
+                # Initialize with prompt
+                current_input = current_prompt[0]
+                mx.eval(current_input)
+
+                while len(current_tokens) < max_tokens:
+                    # Generate one token at a time
+                    logits = model(current_input[None])
+                    next_token = mx.random.categorical(logits[:, -1, :])
+                    token = next_token.item()
+                    current_tokens.append(token)
+                    tokens_generated += 1
+
+                    # Clear intermediate results
+                    mx.eval(next_token)
+                    del logits
+
+                    if token == tokenizer.eos_token_id:
+                        break
+
+                    # Update input for next iteration
+                    current_input = mx.array([token])
+                    mx.eval(current_input)
+
+                    # Clear cache periodically
+                    if len(current_tokens) % 8 == 0:
+                        mx.metal.clear_cache()
+            else:
+                generator = generate_step(
+                    current_prompt[0],
+                    model,
+                    max_tokens=max_tokens,
+                    sampler=lambda x: mx.random.categorical(x)
+                )
+                for token, _ in generator:
+                    current_tokens.append(token)
+                    tokens_generated += 1
+                    if token == tokenizer.eos_token_id:
+                        break
+
+            if current_tokens:
+                token_array = mx.array(current_tokens)
+                mx.eval(token_array)
+                results.append(token_array)
+                del token_array
+        except Exception as e:
+            print(f"Generation failed for sequence {idx}: {e}")
+            continue
+
+        mx.metal.clear_cache()
+
+    if not results:
+        print("No successful generations")
         return None
+
+    mx.eval(results)
+    generation_time = time.perf_counter() - start_time
+    print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
+    return results
 
 def get_per_token_logps(model: nn.Module, inputs, lengths):
     logits = model(inputs).astype(mx.float16)
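The trailing context shows get_per_token_logps casting logits to float16. As a rough sketch of what a helper with this signature typically computes for GRPO (log-softmax over the vocabulary gathered at the shifted target ids, masked to each sequence's valid length), the following is illustrative only and not this file's exact implementation:

    import mlx.core as mx
    import mlx.nn as nn

    def per_token_logps_sketch(model: nn.Module, inputs: mx.array, lengths: mx.array) -> mx.array:
        # inputs: (batch, seq_len) token ids; lengths: (batch,) number of valid tokens per row
        logits = model(inputs).astype(mx.float16)[:, :-1, :]   # predict token t+1 from the prefix ending at t
        targets = inputs[:, 1:]                                 # shifted targets
        log_probs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)  # log-softmax over the vocab
        token_logps = mx.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
        # zero out positions past each sequence's true length
        mask = mx.arange(token_logps.shape[1])[None, :] < (lengths[:, None] - 1)
        return token_logps * mask.astype(token_logps.dtype)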
@@ -209,7 +248,8 @@ def grpo_loss(
             prompt_tensor,
             max_tokens,
             tokenizer,
-            group_size
+            group_size,
+            True
         )
 
         if completions is not None:
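Here grpo_loss passes True positionally for the new is_training flag; writing it as is_training=True would keep the call site self-documenting. A minimal sketch of driving the same path outside the trainer, assuming generate_grpo as defined above and mlx_lm.load for the model and tokenizer (the checkpoint name and prompt are illustrative, not part of this commit):

    import mlx.core as mx
    from mlx_lm import load

    # Illustrative checkpoint; any MLX-format causal LM works the same way.
    model, tokenizer = load("mlx-community/Qwen2.5-1.5B-Instruct-4bit")

    prompt_ids = mx.array(tokenizer.encode("What is 2 + 2?"))[None, :]   # shape (1, prompt_len)

    completions = generate_grpo(
        model,
        prompt_ids,
        max_tokens=128,
        tokenizer=tokenizer,
        group_size=4,
        is_training=True,   # eager token-by-token path with per-step mx.eval()
    )

    if completions is not None:
        for tokens in completions:
            print(tokenizer.decode(tokens.tolist()))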
@@ -222,6 +262,8 @@ def grpo_loss(
                 print(f"Generation error: {e}")
                 continue
 
+    mx.metal.clear_cache()
+
     expanded_answers = []
     expanded_prompts = []
     for i in range(batch_size):
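The body of this loop falls outside the hunk. Purely as an illustration of the usual GRPO bookkeeping at this point, and assuming prompts and answer hold the batch's raw prompt and reference-answer lists, each entry is typically repeated once per completion in the group so rewards can be scored per sample:

    # Hypothetical sketch only; the real loop body is not shown in this diff.
    for i in range(batch_size):
        expanded_prompts.extend([prompts[i]] * group_size)
        expanded_answers.extend([answer[i]] * group_size)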