Mirror of https://github.com/ml-explore/mlx-examples.git

Commit 326935be49 (parent 2d2f39f96e): updates
@@ -175,8 +175,6 @@ def generate_grpo(
     try:
         import time

-        model.freeze()
-
         start_time = time.time()

         if len(prompts.shape) == 1:
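
Note on this hunk: dropping model.freeze() (and the blank line after it) means generation no longer toggles the module's trainable state; the rollouts are detached via mx.stop_gradient further down instead. A minimal freeze/unfreeze sketch for context, assuming mlx.nn semantics (illustrative, not the repo's code):

import mlx.nn as nn

# freeze() marks a module's parameters as non-trainable so nn.value_and_grad
# skips them; unfreeze() restores them. The diff removes this pair around
# generation and relies on the mx.stop_gradient in the @@ -273,7 hunk below.
layer = nn.Linear(4, 4)
layer.freeze()
print(layer.trainable_parameters())  # {} while frozen
layer.unfreeze()
print(sorted(layer.trainable_parameters().keys()))  # ['bias', 'weight']
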
@@ -213,7 +211,6 @@ def generate_grpo(
             sample_start_time = time.time()
             current_tokens = []
             prompt_cache = cache.make_prompt_cache(model)
-            mx.eval(current_tokens, prompt_cache)

             # The generate_step function yields one token at a time
             # We'll collect tokens until we hit max_tokens or a stopping condition
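
Note on this hunk: the removed mx.eval(current_tokens, prompt_cache) evaluated an empty token list and a freshly built, still-empty cache, so it was effectively a no-op. For context, a usage sketch of the prompt cache the hunk keeps; the model name below is a hypothetical stand-in, not from the diff:

from mlx_lm import load
from mlx_lm.models import cache

# Hypothetical model choice, purely for illustration.
model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")

# make_prompt_cache returns one KV-cache object per transformer layer;
# generate_step reuses it across decoding steps so the prompt and earlier
# tokens are not re-processed on every step.
prompt_cache = cache.make_prompt_cache(model)
print(len(prompt_cache))  # one cache entry per layer
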
@@ -226,14 +223,13 @@ def generate_grpo(
                     prompt_cache=prompt_cache,
                 )
             ):
-                print(token)

                 # Check for EOS token
                 if token == tokenizer.eos_token_id:
                     break

                 current_tokens.append(token)
-                mx.eval(current_tokens[-1])
+                print(token)

                 # Check for end token
                 if len(current_tokens) >= len(end_sequence) and mx.array_equal(
@@ -245,6 +241,8 @@ def generate_grpo(
                 if i >= max_tokens - 1:
                     break

+            mx.eval(current_tokens)
+
             if current_tokens:
                 results.append(mx.array(current_tokens))
                 total_tokens_generated += len(current_tokens)
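
Note on this and the previous hunk: evaluation moves from once per token (mx.eval(current_tokens[-1])) to once per completed sample (mx.eval(current_tokens)), and print(token) now runs only for tokens that are actually kept. A minimal sketch of why batching the eval helps under MLX's lazy execution (illustrative values, not the repo's code):

import mlx.core as mx

# MLX builds computation graphs lazily; mx.eval() forces materialization.
tokens = []
for i in range(8):
    t = mx.array(i) + 1  # lazy: nothing is computed yet
    tokens.append(t)

# A single mx.eval over the list flushes all pending work in one pass,
# instead of synchronizing once per iteration as mx.eval(tokens[-1]) did.
mx.eval(tokens)
print([t.item() for t in tokens])
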
@@ -273,7 +271,6 @@ def generate_grpo(

         results = [mx.stop_gradient(r) for r in results]
         mx.eval(results)
-        model.unfreeze()
         return results

     except Exception as e:
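
Note on this hunk: with model.unfreeze() gone, mx.stop_gradient is what keeps the sampled completions out of the autograd graph. A self-contained sketch of its effect on a toy function (not the repo's loss):

import mlx.core as mx

def loss(w):
    y = mx.stop_gradient(w * w)  # y is treated as a constant by autograd
    return (y * w).sum()

w = mx.array([1.0, 2.0])
# With stop_gradient: d/dw (c * w) = c = w**2  -> [1.0, 4.0]
# Without it:         d/dw (w**3)  = 3 * w**2  -> [3.0, 12.0]
print(mx.grad(loss)(w))
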
@@ -885,6 +882,8 @@ def train_grpo(
         n_tokens += toks
         steps += 1

+        mx.metal.clear_cache()
+
         for k, v in metrics.items():
             accumulated_metrics[k] += v

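
Note on this hunk: calling mx.metal.clear_cache() once per step trades a little buffer reallocation for a lower peak-memory footprint during training; newer MLX releases expose the same call as mx.clear_cache(). A minimal sketch of the pattern (illustrative sizes, not the trainer's real workload):

import mlx.core as mx

for step in range(3):
    x = mx.random.normal((1024, 1024))
    y = (x @ x).sum()
    mx.eval(y)
    # Return MLX's pooled GPU buffers to the system so peak memory stays
    # flat across steps instead of growing with the buffer pool.
    mx.metal.clear_cache()
    print("step", step, "active bytes:", mx.metal.get_active_memory())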
|