Merge branch 'main' into adding-GRPO-training

Gökdeniz Gülmez 2025-02-28 21:16:02 +01:00 committed by GitHub
commit 80e10a59d7
5 changed files with 55 additions and 29 deletions

View File

@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.21.5"
+__version__ = "0.21.6"

View File

@@ -65,12 +65,25 @@ def main():
         tokenizer_config={"trust_remote_code": True},
     )
 
-    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
+    def print_help():
+        print("The command list:")
+        print("- 'q' to exit")
+        print("- 'r' to reset the chat")
+        print("- 'h' to display these commands")
+
+    print(f"[INFO] Starting chat session with {args.model}.")
+    print_help()
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
         query = input(">> ")
         if query == "q":
             break
+        if query == "r":
+            prompt_cache = make_prompt_cache(model, args.max_kv_size)
+            continue
+        if query == "h":
+            print_help()
+            continue
         messages = [{"role": "user", "content": query}]
         prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
         for response in stream_generate(
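
For reference, an illustrative session with the new commands (the help text is exactly what print_help() emits; model responses are omitted):

>> h
The command list:
- 'q' to exit
- 'r' to reset the chat
- 'h' to display these commands
>> r
>> q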

View File

@ -64,6 +64,7 @@ CONFIG_DEFAULTS = {
"grad_checkpoint": False, "grad_checkpoint": False,
"lr_schedule": None, "lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}, "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
"mask_prompt": False,
# GRPO args # GRPO args
"reference_model_path": None, "reference_model_path": None,
@ -74,7 +75,7 @@ CONFIG_DEFAULTS = {
"use_chat_template": False, "use_chat_template": False,
"use_prompt": False, "use_prompt": False,
"temperature": 1.0, "temperature": 1.0,
"reward_weights": None, "reward_weights": None
} }
@ -112,7 +113,7 @@ def build_parser():
"--mask-prompt", "--mask-prompt",
action="store_true", action="store_true",
help="Mask the prompt in the loss when training", help="Mask the prompt in the loss when training",
default=False, default=None,
) )
parser.add_argument( parser.add_argument(
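
The change from default=False to default=None is what makes the new "mask_prompt": False entry in CONFIG_DEFAULTS meaningful: with action="store_true" and default=None, a flag the user never passed stays None and can be distinguished from an explicit choice, so the config-file or built-in default wins. A minimal sketch of that merge pattern, assuming a plain dict merge (the merging code here is illustrative, not the helper in this file):

import argparse

CONFIG_DEFAULTS = {"mask_prompt": False}

parser = argparse.ArgumentParser()
parser.add_argument("--mask-prompt", action="store_true", default=None)

args = vars(parser.parse_args([]))  # flag not passed, so mask_prompt is None

# Only options the user explicitly set (non-None) override the defaults.
config = {k: args[k] if args.get(k) is not None else v for k, v in CONFIG_DEFAULTS.items()}
assert config["mask_prompt"] is False  # built-in default applies until --mask-prompt is given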

View File

@@ -181,6 +181,7 @@ class DeepseekV3Attention(nn.Module):
             bias=config.attention_bias,
         )
 
+        if self.config.rope_scaling is not None:
             mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
             scaling_factor = self.config.rope_scaling["factor"]
             if mscale_all_dim:
@@ -205,6 +206,12 @@ class DeepseekV3Attention(nn.Module):
                 base=self.rope_theta,
                 **rope_kwargs,
             )
+        else:
+            self.rope = nn.RoPE(
+                dims=self.qk_rope_head_dim,
+                base=self.rope_theta,
+                traditional=True,
+            )
 
     def __call__(
         self,
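
With the new else branch, a config whose rope_scaling is null no longer hits the .get(...) call on None; it falls back to plain rotary embeddings. A standalone sketch of just that selection logic, assuming mlx is installed (the YaRN branch is elided since its module is defined elsewhere in this file):

import mlx.nn as nn

def build_rope(rope_scaling, qk_rope_head_dim, rope_theta):
    # No scaling config: standard traditional RoPE is the safe default.
    if rope_scaling is None:
        return nn.RoPE(dims=qk_rope_head_dim, base=rope_theta, traditional=True)
    # Otherwise the YaRN path (DeepseekV3YarnRotaryEmbedding) applies, as in the diff.
    raise NotImplementedError("YaRN branch elided in this sketch")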
@@ -487,8 +494,12 @@ class Model(nn.Module):
                 ]
                 weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
 
-        # Remove multi-token prediction layer
-        return {k: v for k, v in weights.items() if not k.startswith("model.layers.61")}
+        # Remove multi-token prediction layer and any unused precomputed rotary freqs
+        return {
+            k: v
+            for k, v in weights.items()
+            if not k.startswith("model.layers.61") and "rotary_emb.inv_freq" not in k
+        }
 
     @property
     def layers(self):
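
A quick pure-Python check of what the widened sanitize filter drops; the key names are hypothetical, only the filter condition is taken from the diff:

weights = {
    "model.layers.0.self_attn.rotary_emb.inv_freq": 0,  # precomputed rotary freqs: dropped
    "model.layers.61.shared_head.head.weight": 1,       # multi-token prediction layer: dropped
    "model.layers.0.mlp.gate.weight": 2,                # kept
}
kept = {
    k: v
    for k, v in weights.items()
    if not k.startswith("model.layers.61") and "rotary_emb.inv_freq" not in k
}
assert list(kept) == ["model.layers.0.mlp.gate.weight"]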

View File

@@ -191,6 +191,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
                 "*.py",
                 "tokenizer.model",
                 "*.tiktoken",
+                "tiktoken.model",
                 "*.txt",
                 "*.jsonl",
             ],
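
The exact "tiktoken.model" entry is needed because allow_patterns is an allow-list and neither existing pattern covers that filename: "tokenizer.model" is a different exact name, and "*.tiktoken" only matches files ending in .tiktoken. A sketch of the download call these patterns feed into, assuming huggingface_hub's snapshot_download with a placeholder repo id:

from huggingface_hub import snapshot_download

path = snapshot_download(
    "some-org/some-model",  # placeholder repo id, not from the diff
    allow_patterns=[
        "*.json",
        "*.safetensors",
        "tokenizer.model",
        "*.tiktoken",
        "tiktoken.model",  # exact filename; "*.tiktoken" would not match it
        "*.txt",
    ],
)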