mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-29 04:31:13 +08:00)

Commit 5ec4790656 (parent baeb9f117f): removing comments + adding temperature + reward weighting
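The diff below adds --temperature and --reward-weights options to the GRPO fine-tuning CLI, threads them through GRPOTrainingArgs into grpo_loss, and replaces the unweighted average of the reward functions with a weighted sum; it also strips a number of comments and docstrings. A minimal, self-contained sketch of the new aggregation step follows (the toy reward functions and example data are illustrative only, not part of the repository):

# Sketch of the weighted-reward aggregation introduced in grpo_loss.
import mlx.core as mx

def toy_format_reward(prompts, completions, answer, **kwargs):
    # 0.5 if the completion contains an <answer> tag, else 0.0
    return [0.5 if "<answer>" in c else 0.0 for c in completions]

def toy_accuracy_reward(prompts, completions, answer, **kwargs):
    # 2.0 if the expected answer string appears in the completion
    return [2.0 if a in c else 0.0 for c, a in zip(completions, answer)]

reward_funcs = [toy_format_reward, toy_accuracy_reward]
reward_weights = [0.2, 1.0]  # one weight per reward function; None means all 1.0

prompts = ["What is 2+2?", "What is 2+2?"]
completions = ["<answer>4</answer>", "five"]
answer = ["4", "4"]

# Collect rewards per function, then stack to shape (num_samples, num_funcs)
all_func_rewards = [
    mx.array(f(prompts=prompts, completions=completions, answer=answer))
    for f in reward_funcs
]
rewards = mx.stack(all_func_rewards, axis=1)

# Apply the weights and sum across functions (previously an unweighted mean)
weights = (
    mx.array(reward_weights, dtype=mx.float32)
    if reward_weights is not None
    else mx.ones(len(reward_funcs), dtype=mx.float32)
)
rewards = (rewards * mx.expand_dims(weights, 0)).sum(axis=1)
print(rewards)  # one scalar reward per completion, fed into the advantage computation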
@@ -1,21 +1,21 @@
 # Copyright © 2024 Apple Inc.

+from pathlib import Path
 import argparse
+import types
 import math
 import os
 import re
-import types
-from pathlib import Path

-import mlx.nn as nn
 import mlx.optimizers as optim
+import mlx.nn as nn
 import numpy as np
 import yaml

+from .tuner.grpo_trainer import GRPOTrainingArgs, evaluate_grpo, train_grpo
+from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tokenizer_utils import TokenizerWrapper
 from .tuner.datasets import load_dataset
-from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
-from .tuner.grpo_trainer import GRPOTrainingArgs, evaluate_grpo, train_grpo
 from .tuner.utils import (
     build_schedule,
     linear_to_lora_layers,
@@ -73,6 +73,8 @@ CONFIG_DEFAULTS = {
     "max_completion_length": 512,
     "use_chat_template": False,
     "use_prompt": False,
+    "temperature": 1.0,
+    "reward_weights": None,
 }


@@ -224,6 +226,18 @@ def build_parser():
         help="Rather to use the prompt from the R1 paper.",
         default=None,
     )
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        help="Temperature for sampling. The higher the temperature, the more random the completions.",
+        default=1.0,
+    )
+    parser.add_argument(
+        "--reward-weights",
+        type=str,
+        help="Weights for each reward function. Must match the number of reward functions and be in this format [0.1, 0.2, 0.3, 0.4, 0.5]. If not given, all rewards are weighted equally with weight `1.0`.",
+        default=None,
+    )
     return parser

 def train_model_grpo(model, tokenizer, args, opt, train_set, valid_set, adapter_file, training_callback):
@@ -241,7 +255,9 @@ def train_model_grpo(model, tokenizer, args, opt, train_set, valid_set, adapter_
         beta=args.beta,
         group_size=args.group_size,
         epsilon=args.epsilon,
-        reference_model_path=args.reference_model_path
+        reference_model_path=args.reference_model_path,
+        temperature=args.temperature,
+        reward_weights=[float(x) for x in args.reward_weights.strip('[]').split(',')] if args.reward_weights else None
     )

     if args.reference_model_path:
@@ -35,13 +35,24 @@ class GRPOTrainingArgs(TrainingArgs):
             "help": "Path to reference model weights. If None, uses the same model."
         }
     )
+    temperature: float = field(
+        default=1.0,
+        metadata={
+            "help": "Temperature for sampling. The higher the temperature, the more random the completions."
+        }
+    )
+    reward_weights: Optional[List[float]] = field(
+        default=None,
+        metadata={
+            "help": "Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are weighted equally with weight `1.0`."
+        }
+    )


 RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]


 def r1_extract_xml_answer(text: str) -> str:
-    """Extracts the answer from an XML formatted text string."""
     try:
         answer = text.split("<answer>")[-1]
         answer = answer.split("</answer>")[0]
@@ -52,14 +63,12 @@ def r1_extract_xml_answer(text: str) -> str:


 def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
-    """Ensures we always return a list of floats."""
     if not completions:
         return [0.0] * len(prompts)
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
     return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]


 def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
-    """Ensures we always return a list of floats."""
     if not completions or not answer:
         return [0.0] * len(prompts)
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
@@ -67,7 +76,6 @@ def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kw


 def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
-    """Ensures we always return a list of floats."""
     if not completions:
         return [0.0] * len(prompts)
     pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
@@ -76,7 +84,6 @@ def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, *


 def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
-    """Ensures we always return a list of floats."""
     if not completions:
         return [0.0] * len(prompts)
     pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
@@ -85,16 +92,13 @@ def r1_strict_format_reward_func(prompts: list, completions: list, answer: list,


 def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
-    """Ensures we always return a list of floats."""
     if not completions:
         return [0.0] * len(prompts)
-
     scores = []
     for text in completions:
         if not text:
             scores.append(0.0)
             continue
-
         count = 0.0
         if text.count("<think>\n") == 1:
             count += 0.125
@@ -104,13 +108,9 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
             count += 0.125
         if text.count("\n</answer>\n") == 1:
             count += 0.125
-
-        # Penalize extra text after </answer>
         end_text = text.split("\n</answer>\n")[-1]
         count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
-
-        scores.append(max(0.0, count)) # Ensure non-negative score
-
+        scores.append(max(0.0, count))
     return scores


@@ -119,22 +119,18 @@ def generate_grpo(model: nn.Module, prompt, max_tokens, tokenizer, temperature):
         prompt = prompt[None, :]
     if prompt.shape[1] == 0:
         return None
-
     end_sequence = tokenizer.encode("</answer>")
     end_sequence_length = len(end_sequence)
-
     initial_length = prompt.shape[1]
     output = mx.zeros((initial_length + max_tokens,), dtype=mx.int32)
     output[:initial_length] = prompt[0]
     current_length = initial_length
-
     try:
         def sample(logits):
             if temperature > 0:
                 logits /= temperature
             logprobs = logits - mx.logsumexp(logits, keepdims=True)
             return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]
-
         for _ in range(max_tokens):
             current_input = output[:current_length][None, :]
             logits = model(current_input)
@@ -143,18 +139,14 @@ def generate_grpo(model: nn.Module, prompt, max_tokens, tokenizer, temperature):
             token_value = next_token.item()
             output[current_length] = token_value
             current_length += 1
-
             if token_value == tokenizer.eos_token_id:
                 break
-
             if current_length >= end_sequence_length:
                 last_tokens = output[current_length - end_sequence_length:current_length].tolist()
                 if last_tokens == end_sequence:
                     break
-
         if current_length > initial_length:
             return output[:current_length]
-
     except Exception as e:
         print(f"Generation error: {str(e)}")
         return None
@@ -192,9 +184,10 @@ def grpo_loss(
     group_size=4,
     epsilon=1e-4,
     max_tokens=64,
-    temperature=1.0
+    temperature=1.0,
+    reward_weights=None
 ):
-    prompt_tokens, answer_tokens, prompt_text, answer_text = batch
+    prompt_tokens, _, prompt_text, answer_text = batch
     batch_size = len(prompt_tokens)

     all_completions = []
@@ -273,18 +266,34 @@ def grpo_loss(
     token_log_probs = mx.stack(padded_log_probs)
     ref_token_log_probs = mx.stack(padded_ref_log_probs)

-    # Rewards and advantages
-    rewards = mx.zeros((len(all_completions),))
+    # Create array to store rewards from each function
+    all_func_rewards = []
+
+    # Collect rewards from each function separately
     for reward_func in reward_funcs:
         func_rewards = mx.array(reward_func(
             prompts=expanded_prompts,
             completions=all_completion_texts,
             answer=expanded_answers
         ))
-        rewards += func_rewards
+        all_func_rewards.append(func_rewards)

-    if len(reward_funcs) > 1:
-        rewards /= len(reward_funcs)
+    # Stack rewards to shape (num_samples, num_funcs)
+    rewards = mx.stack(all_func_rewards, axis=1)
+    print(f"Rewards: {rewards}")
+
+    # Apply weights and sum
+    if reward_weights is not None:
+        if len(reward_weights) != len(reward_funcs):
+            raise ValueError(
+                f"Number of reward weights ({len(reward_weights)}) must match number of reward "
+                f"functions ({len(reward_funcs)})"
+            )
+        reward_weights = mx.array(reward_weights, dtype=mx.float32)
+    else:
+        reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32)
+    rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
+    print(f"Rewards after weights: {rewards}")

     # Reshape rewards and compute advantages
     rewards_reshaped = rewards.reshape(batch_size, group_size)
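With this change, each sampled completion receives a weighted sum of the reward functions before the rewards are reshaped to (batch_size, group_size) for the per-group advantage computation that follows outside this hunk. In math form (with $w_k = 1$ for every $k$ when no weights are given):

$$r_i = \sum_{k} w_k \, R_k(\text{prompt}_i, \text{completion}_i, \text{answer}_i)$$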
@@ -397,15 +406,11 @@ def evaluate_grpo(
     epsilon: float,
     group_size: int,
     max_seq_length,
+    temperature: float,
     reward_funcs: Optional[List[RewardFunctions]] = None,
     loss_fn: callable = grpo_loss,
     iterate_batches: callable = iterate_grpo_batches
 ):
-    """
-    Evaluate model using GRPO loss.
-    Returns:
-        tuple: (average loss, number of tokens, average metrics)
-    """
     all_losses = 0
     ntokens = 0
     all_metrics = None
@@ -428,7 +433,8 @@ def evaluate_grpo(
             beta=beta,
             group_size=group_size,
             epsilon=epsilon,
-            ref_model=ref_model
+            ref_model=ref_model,
+            temperature=temperature
         )

         all_losses += losses * toks
@@ -442,12 +448,10 @@ def evaluate_grpo(

     mx.eval(all_losses, ntokens)

-    # Aggregate across distributed workers
     all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
     ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
     all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}

-    # Calculate averages
     avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
     avg_loss = (all_losses / ntokens).item()

@@ -486,8 +490,6 @@ def train_grpo(
     state = [model.state, optimizer.state]

     def step(batch):
-
-        # Forward and backward pass
         (loss, toks, metrics), grad = loss_value_and_grad(
             model,
             tokenizer=tokenizer,
@@ -498,12 +500,11 @@ def train_grpo(
             epsilon=args.epsilon,
             ref_model=ref_model,
             max_tokens=args.max_completion_length,
+            temperature=args.temperature
         )

-        # All reduce the gradients if running in distributed mode
         grad = average_gradients(grad)

-        # Model update
         optimizer.update(model, grad)

         return loss, toks, metrics
@@ -536,8 +537,6 @@ def train_grpo(
             train=True,
         ),
     ):
-        # Report validation loss if needed, the first validation loss
-        # is always measured before any training.
         if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
             stop = time.perf_counter()
             val_loss, val_ntokens, val_metrics = evaluate_grpo(
@@ -553,6 +552,7 @@ def train_grpo(
                 max_seq_length=args.max_seq_length,
                 beta=args.beta,
                 epsilon=args.epsilon,
+                temperature=args.temperature,
                 iterate_batches=iterate_batches,
             )
             val_time = time.perf_counter() - stop
@@ -566,7 +566,6 @@ def train_grpo(
                 f"Val kl {val_metrics['kl']:.3f}"
             )

-            # Add reward function specific metrics
             for i, reward_func in enumerate(reward_funcs):
                 val_metrics_str += (
                     f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, "
@@ -622,7 +621,6 @@ def train_grpo(
                 f"KL {avg_metrics['kl']:.3f}"
             )

-            # Add reward function specific metrics
             for i, reward_func in enumerate(reward_funcs):
                 func_name = reward_func.__name__
                 train_metrics_str += (
@@ -656,7 +654,6 @@ def train_grpo(
             steps = 0
             start = time.perf_counter()

-        # Save adapter weights
         if it % args.steps_per_save == 0:
             adapter_weights = dict(tree_flatten(model.trainable_parameters()))
             mx.save_safetensors(str(args.adapter_file), adapter_weights)
@@ -669,7 +666,6 @@ def train_grpo(
                 f"{args.adapter_file} and {checkpoint}."
             )

-    # Save final weights
     adapter_weights = dict(tree_flatten(model.trainable_parameters()))
     mx.save_safetensors(str(args.adapter_file), adapter_weights)
     print(f"Saved final weights to {args.adapter_file}.")