Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-26 10:41:18 +08:00)

first successful training run

This commit is contained in:
parent ca32424043
commit 7173840283
@@ -63,11 +63,16 @@ CONFIG_DEFAULTS = {
     "config": None,
     "grad_checkpoint": False,
     "lr_schedule": None,
+    "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
+
+    # GRPO args
     "reference_model_path": None,
     "group_size": 4,
     "beta": 0.1,
     "epsilon": 1e-4,
-    "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
+    "max_completion_length": 512,
+    "use_chat_template": False,
+    "use_prompt": False,
 }
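For orientation, the GRPO-related defaults introduced in this hunk are collected below. This is an illustrative summary, not repository code; the keys and values are copied from the diff, the comments paraphrase the help strings added further down, and the meaning given for beta is an assumption based on the KL metric reported by the trainer.

# Illustrative summary of the GRPO defaults added above (not repository code).
GRPO_CONFIG_DEFAULTS = {
    "reference_model_path": None,  # if unset, train_model below falls back to a second copy of the trained model
    "group_size": 4,               # completions sampled per prompt
    "beta": 0.1,                   # assumed: weight of the KL term against the reference model
    "epsilon": 1e-4,               # numerical-stability constant
    "max_completion_length": 512,  # token budget per generated completion
    "use_chat_template": False,    # wrap prompts with the tokenizer's chat template
    "use_prompt": False,           # use the R1-style plain-text prompt instead
}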
@@ -178,9 +183,15 @@ def build_parser():
     parser.add_argument(
         "--group-size",
         type=int,
-        help="Number of responses per prompt.",
+        help="Number of generations per prompt.",
         default=4,
     )
+    parser.add_argument(
+        "--max-completion-length",
+        type=int,
+        help="Maximum number of tokens to generate per completion.",
+        default=512,
+    )
     parser.add_argument(
         "--beta",
         type=float,
@@ -193,6 +204,18 @@ def build_parser():
         help="The Epsilon for numerical stability.",
         default=1e-4,
     )
+    parser.add_argument(
+        "--use-chat-template",
+        type=bool,
+        help="If the model is a chat model, use the chat template.",
+        default=False,
+    )
+    parser.add_argument(
+        "--use-prompt",
+        type=bool,
+        help="Whether to use the prompt from the R1 paper.",
+        default=False,
+    )
     return parser
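The two new boolean switches are declared with type=bool, which argparse applies as a plain string-to-bool conversion. The self-contained sketch below (a stand-in parser mirroring the arguments added above, with illustrative values) shows how the new options parse and why the boolean flags should simply be omitted when they are meant to stay False.

import argparse

# Stand-in parser mirroring the GRPO arguments added above (illustrative, not repository code).
parser = argparse.ArgumentParser()
parser.add_argument("--group-size", type=int, default=4)
parser.add_argument("--max-completion-length", type=int, default=512)
parser.add_argument("--use-chat-template", type=bool, default=False)
parser.add_argument("--use-prompt", type=bool, default=False)

args = parser.parse_args(["--group-size", "8", "--max-completion-length", "256"])
print(args.group_size, args.max_completion_length)  # -> 8 256

# Caveat: with type=bool, any non-empty string (including "False") converts to True,
# so omit these flags entirely to keep their False defaults.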
@@ -262,6 +285,7 @@ def train_model(
             steps_per_save=args.save_every,
             adapter_file=adapter_file,
             max_seq_length=args.max_seq_length,
+            max_completion_length=args.max_completion_length,
             grad_checkpoint=args.grad_checkpoint,
             beta=args.beta,
             group_size=args.group_size,
@@ -273,7 +297,7 @@ def train_model(
         reference_model, _ = load(args.reference_model_path)
         reference_model = reference_model.freeze()
     else:
-        reference_model, _ = None, None
+        reference_model, _ = load(args.model)

     train_grpo(
         model=model,
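With this change GRPO always has a reference policy: the checkpoint named by --reference-model-path if one is given, otherwise a second copy of the model being trained. A condensed sketch of that branch follows; the wrapper function is illustrative, and load refers to the same model loader used in train_model above.

# Illustrative wrapper around the branch above; `load` is the loader used in train_model.
def resolve_reference_model(args):
    if args.reference_model_path:
        reference_model, _ = load(args.reference_model_path)
        return reference_model.freeze()
    # Fallback added in this commit: reuse the trained model as its own reference.
    reference_model, _ = load(args.model)
    return reference_model  # note: the committed code freezes only the explicit checkpoint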
@@ -16,14 +16,33 @@ class GRPODataset:
         data: List[Dict[str, str]],
         tokenizer: PreTrainedTokenizer,
         prompt_key: str = "prompt",
-        answer_key: str = "answer"
+        answer_key: str = "answer",
+        use_chat_template: bool = False,
+        use_prompt: bool = False
     ):
         self._data = []
         for item in data:
             prompt_str = str(item[prompt_key])
             answer_str = str(item[answer_key])
-            prompt_tokens = tokenizer.encode(prompt_str)
-            answer_tokens = tokenizer.encode(answer_str)
+            if use_chat_template:
+                prompt_tokens = tokenizer.apply_chat_template(
+                    [
+                        {'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
+The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
+The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."""},
+                        {'role': 'user', 'content': prompt_str}
+                    ],
+                )
+                answer_tokens = tokenizer.encode(answer_str)
+            else:
+                if use_prompt:
+                    prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
+The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
+The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.
+User: {prompt_str}. Assistant: """)
+                else:
+                    prompt_tokens = tokenizer.encode(prompt_str)
+                answer_tokens = tokenizer.encode(answer_str)
             self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))

     def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
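When use_prompt is set, each prompt is wrapped in the R1-style template shown above before tokenization. The following self-contained sketch isolates that formatting step; the helper name is illustrative, while the template text is the one added in this hunk.

# Illustrative helper for the plain-text (use_prompt=True) formatting used above.
R1_SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it.\n"
    "The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\n"
    "The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, "
    "respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."
)

def build_r1_prompt(prompt_str: str) -> str:
    return f"{R1_SYSTEM_PROMPT}\nUser: {prompt_str}. Assistant: "

print(build_r1_prompt("What is 2 + 2?"))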
@@ -12,8 +12,6 @@ from mlx.utils import tree_flatten

 from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches

-from mlx_lm.utils import generate_step
-

 @dataclass
 class GRPOTrainingArgs(TrainingArgs):
@@ -27,6 +25,9 @@ class GRPOTrainingArgs(TrainingArgs):
     epsilon: float = field(
         default=1e-4, metadata={"help": "The Epsilon for numerical stability."}
     )
+    max_completion_length: int = field(
+        default=512, metadata={"help": "Maximum number of tokens to generate per completion."}
+    )
     reference_model_path: str = field(
         default=None,
         metadata={
@@ -36,7 +37,6 @@ class GRPOTrainingArgs(TrainingArgs):


 def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
-    model.eval()
     if len(prompt.shape) == 1:
         prompt = prompt[None, :]
@@ -58,11 +58,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):

         token_value = next_token.item()
         generated.append(next_token)
-
-        # Clear intermediate tensors
-        del logits, token_logits, probs
-        mx.metal.clear_cache()

         current_prompt = mx.concatenate([current_prompt, next_token[None]])
         if token_value == tokenizer.eos_token_id:
             break
@@ -72,12 +68,6 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):

     result = mx.concatenate([prompt[0], mx.stack(generated)])
     mx.eval(result)
-    model.train()
-
-    # Clear generated tokens
-    del generated
-    mx.metal.clear_cache()
-
     return result
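generate_grpo samples completions token by token at a given temperature. The generic sketch below shows the sampling step only; it is not the repository's exact loop body, and it assumes mlx is available (mx.random.categorical draws from the softmax of the scaled logits).

import mlx.core as mx

# Generic sketch of one temperature-sampling step (not the exact code above).
def sample_next_token(logits: mx.array, temperature: float = 1.0) -> mx.array:
    token_logits = logits[:, -1, :] / max(temperature, 1e-6)  # last position, temperature-scaled
    return mx.random.categorical(token_logits)                # one sampled token id per sequence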
@@ -192,11 +182,6 @@ def get_per_token_logps(model, inputs, lengths):
         ).squeeze(-1)  # [seq_len]

         per_token_logps.append(token_log_probs)
-
-        # Clean up intermediates
-        del seq_logits, seq_targets, log_probs, token_log_probs
-        mx.metal.clear_cache()
-
     return per_token_logps
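get_per_token_logps collects, for each position, the log-probability the model assigns to the token that actually follows. Below is a compact, generic sketch of that computation on a plain logits tensor; it is not the repository's implementation, and the shapes are noted in the comments.

import mlx.core as mx
import mlx.nn as nn

# Generic sketch: log-probability of each next token, given full-sequence logits.
def per_token_logps(logits: mx.array, inputs: mx.array) -> mx.array:
    # logits: [batch, seq_len, vocab]; inputs: [batch, seq_len] token ids
    log_probs = nn.log_softmax(logits[:, :-1, :], axis=-1)  # position t predicts token t+1
    targets = inputs[:, 1:]                                  # shifted inputs are the targets
    return mx.take_along_axis(log_probs, targets[..., None], axis=-1).squeeze(-1)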
@@ -232,15 +217,9 @@ def grpo_loss(
                 all_completions.append(completion_ids)
                 all_completion_texts.append(completion_text)
-
-                del completion_ids
-                mx.metal.clear_cache()

             except Exception as e:
                 print(f"Generation error: {e}")
                 continue
-
-        del prompt_tensor
-        mx.metal.clear_cache()

     # Prepare inputs
     expanded_answers = []
@@ -264,25 +243,11 @@ def grpo_loss(
         mask = mx.ones_like(completion_ids)
         padded_completions.append(padded_ids)
         attention_masks.append(mask)
-
-        del completion_ids
-        if padding_length > 0:
-            del padding
-        del mask
-        mx.metal.clear_cache()

     inputs = mx.stack(padded_completions)
     attention_mask = mx.stack(attention_masks)
     lengths = attention_mask.sum(axis=1)
-
-    del padded_completions, attention_masks
-    mx.metal.clear_cache()

-    # Get logits and compute log probabilities
-    logits = model(inputs).astype(mx.float32)
-    log_probs = nn.log_softmax(logits[:, :-1, :], axis=-1)
-    targets = inputs[:, 1:]
-
     # Current policy probabilities
     token_log_probs = get_per_token_logps(model, inputs, lengths)
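Before the log-probability pass, the variable-length completions are right-padded to a common length and stacked, and the masks record each completion's true length. The helper below is a generic, self-contained sketch of that pattern, not the committed code.

import mlx.core as mx

# Generic sketch of right-padding token sequences and deriving length masks.
def pad_and_mask(completions, pad_token_id):
    max_len = max(ids.shape[0] for ids in completions)
    padded, masks = [], []
    for ids in completions:
        pad_len = max_len - ids.shape[0]
        mask = mx.concatenate([mx.ones_like(ids), mx.zeros((pad_len,), dtype=ids.dtype)])
        if pad_len > 0:
            ids = mx.concatenate([ids, mx.full((pad_len,), pad_token_id, dtype=ids.dtype)])
        padded.append(ids)
        masks.append(mask)
    inputs = mx.stack(padded)
    attention_mask = mx.stack(masks)
    lengths = attention_mask.sum(axis=1)  # true (unpadded) length of each completion
    return inputs, attention_mask, lengths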
@@ -302,9 +267,6 @@ def grpo_loss(

         padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
         padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))
-
-        del padding
-        mx.metal.clear_cache()

     token_log_probs = mx.stack(padded_log_probs)
     ref_token_log_probs = mx.stack(padded_ref_log_probs)
@@ -360,10 +322,6 @@ def grpo_loss(
         reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards)
         reward_metrics[f'{func_name}_std'] = mx.std(func_rewards)

-    # Clean up
-    del all_completions
-    mx.metal.clear_cache()
-
     metrics = {
         'total_rewards_mean': mx.mean(rewards),
         'total_rewards_std': mx.std(rewards),
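The grouped reward statistics feed GRPO's group-relative advantages: each completion's reward is normalized against the other completions generated for the same prompt. The sketch below shows that normalization using the group_size and epsilon parameters introduced in this commit; it is a generic formulation, not the repository's exact code.

import mlx.core as mx

# Generic sketch of GRPO's group-relative advantage normalization.
def group_relative_advantages(rewards: mx.array, group_size: int, epsilon: float = 1e-4) -> mx.array:
    grouped = rewards.reshape(-1, group_size)                # [num_prompts, group_size]
    mean = mx.mean(grouped, axis=1, keepdims=True)
    std = mx.std(grouped, axis=1, keepdims=True)
    return ((grouped - mean) / (std + epsilon)).reshape(-1)  # one advantage per completion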
@@ -440,7 +398,7 @@ def evaluate_grpo(
     group_size: int,
     max_seq_length,
     reward_funcs = None,
-    loss: callable = grpo_loss,
+    loss_fn: callable = grpo_loss,
     iterate_batches: callable = iterate_grpo_batches
 ):
     """
@@ -466,7 +424,7 @@ def evaluate_grpo(
         ),
     ):
         # Calculate loss for current batch
-        losses, toks, metrics = loss(
+        losses, toks, metrics = loss_fn(
             model=model,
             tokenizer=tokenizer,
             batch=batch,
@@ -518,7 +476,7 @@ def train_grpo(
         r1_count_xml
     ],
     args: GRPOTrainingArgs = GRPOTrainingArgs(),
-    loss: callable = grpo_loss,
+    loss_fn: callable = grpo_loss,
     iterate_batches: callable = iterate_grpo_batches,
     training_callback: TrainingCallback = None,
 ):
@@ -546,7 +504,7 @@ def train_grpo(
             group_size=args.group_size,
             epsilon=args.epsilon,
             ref_model=ref_model,
-            max_tokens=args.max_seq_length,
+            max_tokens=args.max_completion_length,
         )

         # All reduce the gradients if running in distributed mode
@@ -557,22 +515,23 @@ def train_grpo(

         return loss, toks, metrics

-    loss_value_and_grad = nn.value_and_grad(model, loss)
+    loss_value_and_grad = nn.value_and_grad(model, loss_fn)

     losses = 0
     n_tokens = 0
     steps = 0
     trained_tokens = 0
     accumulated_metrics = {
-        'rewards': 0,
-        'rewards_std': 0,
-        'grouped_rewards': 0,
+        'total_rewards_mean': 0,
+        'total_rewards_std': 0,
+        'grouped_rewards_mean': 0,
         'grouped_rewards_std': 0,
         'kl': 0
     }
-    for i in range(len(reward_funcs)):
-        accumulated_metrics[f'reward_func_{i}_mean'] = 0
-        accumulated_metrics[f'reward_func_{i}_std'] = 0
+    for reward_func in reward_funcs:
+        func_name = reward_func.__name__
+        accumulated_metrics[f'{func_name}_mean'] = 0
+        accumulated_metrics[f'{func_name}_std'] = 0

     start = time.perf_counter()
     for it, batch in zip(
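The accumulator keys now line up with what grpo_loss reports: the fixed entries use the total_*/grouped_* names, and per-reward-function entries are keyed by the function's __name__. The short self-contained sketch below demonstrates the naming scheme with a hypothetical reward function (r1_accuracy_reward is illustrative, not one of the repository's reward functions).

# Self-contained sketch of the metric-key naming used above.
def r1_accuracy_reward(completions, answers):  # hypothetical reward function
    return [1.0 if c.strip() == a.strip() else 0.0 for c, a in zip(completions, answers)]

accumulated_metrics = {
    'total_rewards_mean': 0, 'total_rewards_std': 0,
    'grouped_rewards_mean': 0, 'grouped_rewards_std': 0,
    'kl': 0,
}
for reward_func in [r1_accuracy_reward]:
    func_name = reward_func.__name__
    accumulated_metrics[f'{func_name}_mean'] = 0
    accumulated_metrics[f'{func_name}_std'] = 0

print(sorted(accumulated_metrics))  # includes 'r1_accuracy_reward_mean' and 'r1_accuracy_reward_std'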
@@ -592,7 +551,7 @@ def train_grpo(
             val_loss, val_ntokens, val_metrics = evaluate_grpo(
                 model=model,
                 dataset=val_dataset,
-                loss=loss,
+                loss_fn=loss_fn,
                 ref_model=ref_model,
                 reward_funcs=reward_funcs,
                 tokenizer=tokenizer,
@@ -675,8 +634,8 @@ def train_grpo(
             for i, reward_func in enumerate(reward_funcs):
                 func_name = reward_func.__name__
                 train_metrics_str += (
-                    f", Reward func {reward_func.__name__} mean {avg_metrics[f'reward_func_{reward_func.__name__}_mean']:.3f}, "
-                    f"Reward func {reward_func.__name__} std {avg_metrics[f'reward_func_{reward_func.__name__}_std']:.3f}"
+                    f", {func_name} mean {avg_metrics[f'{func_name}_mean']:.3f}, "
+                    f"{func_name} std {avg_metrics[f'{func_name}_std']:.3f}"
                 )

             print(