Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-26 02:33:23 +08:00
Training mode working too; went from 2 toks/sec to 30 toks/sec with the raw 1.5B model.
This commit is contained in:
parent 6086137131
commit 710bc1490e
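The diff below adds an is_training flag to generate_grpo: in training mode it samples token by token with eager mx.eval() calls and periodic Metal cache clearing, while inference keeps the generate_step path. A hypothetical usage sketch (the mlx_lm load() call and the model name are illustrative assumptions; only the generate_grpo signature comes from this commit):

    # Hypothetical usage of the updated generate_grpo (defined in the diff below).
    import mlx.core as mx
    from mlx_lm import load

    model, tokenizer = load("Qwen/Qwen2.5-1.5B-Instruct")  # assumed checkpoint

    # One prompt, shape (1, seq_len); generate_grpo expands it group_size times.
    prompt = mx.array(tokenizer.encode("What is 2 + 2?"))[None, :]

    # is_training=True selects the new token-by-token path with eager mx.eval()
    # and periodic mx.metal.clear_cache(); is_training=False keeps generate_step.
    completions = generate_grpo(
        model,
        prompt,
        max_tokens=256,
        tokenizer=tokenizer,
        group_size=4,
        is_training=True,
    )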
@@ -112,51 +112,90 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list:
     return scores
 
 
-def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size):
+def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False):
     if len(prompts.shape) == 1:
         prompts = prompts[None, :]
     if prompts.shape[1] == 0:
         return None
 
     model.eval()
     batch_size = prompts.shape[0] * group_size
     expanded_prompts = mx.repeat(prompts, group_size, axis=0)
+    mx.eval(expanded_prompts)
 
     results = []
     tokens_generated = 0
     start_time = time.perf_counter()
 
     try:
-        for idx in range(batch_size):
-            current_tokens = []
-            generator = generate_step(
-                expanded_prompts[idx],
-                model,
-                max_tokens=max_tokens,
-                sampler=lambda x: mx.argmax(x, axis=-1)
-            )
-
-            # Collect all tokens first
-            for tokens, _ in generator:
-                current_tokens.append(tokens)
-                tokens_generated += 1
-                if tokens == tokenizer.eos_token_id:
-                    break
-
-            # Convert to array after collection
-            results.append(mx.array(current_tokens))
-
-        # Final evaluation of all results
-        mx.eval(results)
-        generation_time = time.perf_counter() - start_time
-        print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
-        return results
+        for idx in range(batch_size):
+            current_prompt = expanded_prompts[idx:idx+1]
+            mx.eval(current_prompt)
+
+            current_tokens = []
+            try:
+                if is_training:
+                    # Initialize with prompt
+                    current_input = current_prompt[0]
+                    mx.eval(current_input)
+
+                    while len(current_tokens) < max_tokens:
+                        # Generate one token at a time
+                        logits = model(current_input[None])
+                        next_token = mx.random.categorical(logits[:, -1, :])
+                        token = next_token.item()
+                        current_tokens.append(token)
+                        tokens_generated += 1
+
+                        # Clear intermediate results
+                        mx.eval(next_token)
+                        del logits
+
+                        if token == tokenizer.eos_token_id:
+                            break
+
+                        # Update input for next iteration
+                        current_input = mx.array([token])
+                        mx.eval(current_input)
+
+                        # Clear cache periodically
+                        if len(current_tokens) % 8 == 0:
+                            mx.metal.clear_cache()
+                else:
+                    generator = generate_step(
+                        current_prompt[0],
+                        model,
+                        max_tokens=max_tokens,
+                        sampler=lambda x: mx.random.categorical(x)
+                    )
+
+                    for token, _ in generator:
+                        current_tokens.append(token)
+                        tokens_generated += 1
+                        if token == tokenizer.eos_token_id:
+                            break
+
+                if current_tokens:
+                    token_array = mx.array(current_tokens)
+                    mx.eval(token_array)
+                    results.append(token_array)
+                    del token_array
+
+                mx.metal.clear_cache()
+            except Exception as e:
+                print(f"Generation failed for sequence {idx}: {e}")
+                continue
+
+        mx.metal.clear_cache()
 
     except Exception as e:
         print(f"Generation error: {str(e)}")
+
+    if not results:
+        print("No successful generations")
         return None
 
+    mx.eval(results)
+
+    generation_time = time.perf_counter() - start_time
+    print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
+
+    return results
 
 
 def get_per_token_logps(model: nn.Module, inputs, lengths):
     logits = model(inputs).astype(mx.float16)
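The is_training branch above forces each sampled token with mx.eval, deletes the intermediate logits, and calls mx.metal.clear_cache() every 8 tokens; the commit message credits this eager pattern for the jump from ~2 to ~30 toks/sec. A minimal standalone sketch of the same eval-and-clear pattern, assuming an mlx-lm style model and tokenizer (placeholders); unlike the diff, it re-feeds the growing context each step:

    # Minimal sketch of the eval-and-clear pattern from the training path.
    import mlx.core as mx

    def decode_with_cache_clearing(model, tokenizer, prompt_ids, max_tokens=64):
        tokens = []
        context = prompt_ids  # 1-D array of token ids
        for _ in range(max_tokens):
            logits = model(context[None])                      # (1, len, vocab)
            next_token = mx.argmax(logits[:, -1, :], axis=-1)  # shape (1,)
            mx.eval(next_token)  # force evaluation now instead of building
            del logits           # one long lazy graph across iterations
            token = next_token.item()
            tokens.append(token)
            if token == tokenizer.eos_token_id:
                break
            context = mx.concatenate([context, next_token])
            if len(tokens) % 8 == 0:
                mx.metal.clear_cache()  # bound Metal buffer-cache growth
        return tokens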
@@ -209,7 +248,8 @@ def grpo_loss(
                 prompt_tensor,
                 max_tokens,
                 tokenizer,
-                group_size
+                group_size,
+                True
             )
 
             if completions is not None:
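The new argument is passed positionally, so the added True binds to is_training. As a readability aside (not part of the commit), the same call could name the flag explicitly; the call head and the model argument are assumed from the signature and surrounding context:

    completions = generate_grpo(  # call head assumed from context
        model,
        prompt_tensor,
        max_tokens,
        tokenizer,
        group_size,
        is_training=True,
    )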
@@ -221,6 +261,8 @@ def grpo_loss(
         except Exception as e:
             print(f"Generation error: {e}")
             continue
 
+    mx.metal.clear_cache()
+
     expanded_answers = []
     expanded_prompts = []