Mirror of https://github.com/ml-explore/mlx-examples.git
Commit 80e10a59d7: Merge branch 'main' into adding-GRPO-training
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.21.5"
+__version__ = "0.21.6"
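A quick sanity check on the bump, assuming (as _version.py suggests) that the package re-exports __version__ at the top level:

# Hypothetical check, not part of the diff: confirm the installed release.
import mlx_lm

print(mlx_lm.__version__)  # expect "0.21.6" after this bump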
@@ -65,12 +65,25 @@ def main():
         tokenizer_config={"trust_remote_code": True},
     )
 
-    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
+    def print_help():
+        print("The command list:")
+        print("- 'q' to exit")
+        print("- 'r' to reset the chat")
+        print("- 'h' to display these commands")
+
+    print(f"[INFO] Starting chat session with {args.model}.")
+    print_help()
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
         query = input(">> ")
         if query == "q":
             break
+        if query == "r":
+            prompt_cache = make_prompt_cache(model, args.max_kv_size)
+            continue
+        if query == "h":
+            print_help()
+            continue
         messages = [{"role": "user", "content": query}]
         prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
         for response in stream_generate(
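The new 'r' command resets the conversation by rebuilding the prompt cache, since the cache is what carries the model's KV state between turns. A minimal sketch of that behaviour outside the CLI, assuming mlx-lm is installed (the model id below is only an example):

# Minimal sketch, not the repo's chat.py: shows how the prompt cache keeps
# multi-turn context and why a fresh make_prompt_cache() starts a new chat.
from mlx_lm import load, stream_generate
from mlx_lm.models.cache import make_prompt_cache

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # example model id
prompt_cache = make_prompt_cache(model)

def ask(text):
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": text}], add_generation_prompt=True
    )
    for response in stream_generate(model, tokenizer, prompt, prompt_cache=prompt_cache):
        print(response.text, end="", flush=True)
    print()

ask("Hi, my name is Ada.")
ask("What is my name?")                   # answered from the cached turns above
prompt_cache = make_prompt_cache(model)   # the 'r' command: new cache, new conversation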
@@ -64,6 +64,7 @@ CONFIG_DEFAULTS = {
     "grad_checkpoint": False,
     "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
+    "mask_prompt": False,
 
     # GRPO args
     "reference_model_path": None,
@@ -74,7 +75,7 @@ CONFIG_DEFAULTS = {
     "use_chat_template": False,
     "use_prompt": False,
     "temperature": 1.0,
-    "reward_weights": None,
+    "reward_weights": None
 }
 
 
@@ -112,7 +113,7 @@ def build_parser():
         "--mask-prompt",
         action="store_true",
         help="Mask the prompt in the loss when training",
-        default=False,
+        default=None,
     )
 
     parser.add_argument(
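Together with the new "mask_prompt": False entry in CONFIG_DEFAULTS, switching the flag's argparse default from False to None lets a value in a YAML config file take effect unless --mask-prompt is passed explicitly. A toy illustration of that precedence (not lora.py's exact merging code):

# Toy illustration of why the argparse default moves from False to None:
# None marks "not given on the CLI", so a YAML config value is not silently
# clobbered by the flag's built-in default.
CONFIG_DEFAULTS = {"mask_prompt": False}

def resolve(cli: dict, config_file: dict) -> dict:
    merged = dict(cli)
    for k, v in config_file.items():       # config file fills CLI gaps
        if merged.get(k) is None:
            merged[k] = v
    for k, v in CONFIG_DEFAULTS.items():   # built-in defaults fill whatever is left
        if merged.get(k) is None:
            merged[k] = v
    return merged

print(resolve({"mask_prompt": None}, {"mask_prompt": True}))  # {'mask_prompt': True}
print(resolve({"mask_prompt": None}, {}))                     # {'mask_prompt': False}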
@@ -181,6 +181,7 @@ class DeepseekV3Attention(nn.Module):
             bias=config.attention_bias,
         )
 
+        if self.config.rope_scaling is not None:
             mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
             scaling_factor = self.config.rope_scaling["factor"]
             if mscale_all_dim:
@@ -205,6 +206,12 @@ class DeepseekV3Attention(nn.Module):
                 base=self.rope_theta,
                 **rope_kwargs,
             )
+        else:
+            self.rope = nn.RoPE(
+                dims=self.qk_rope_head_dim,
+                base=self.rope_theta,
+                traditional=True,
+            )
 
     def __call__(
         self,
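The added else branch means DeepSeek-V3 configs that set "rope_scaling": null fall back to a plain rotary embedding instead of the YaRN variant. A small sketch of that fallback, with illustrative dimensions rather than the model's real sizes:

# Sketch of the fallback path; head dim and base are illustrative values only.
import mlx.core as mx
import mlx.nn as nn

qk_rope_head_dim = 64
rope_theta = 10000.0
rope_scaling = None  # e.g. "rope_scaling": null in config.json

if rope_scaling is not None:
    ...  # the YaRN rotary embedding path, omitted in this sketch
else:
    rope = nn.RoPE(dims=qk_rope_head_dim, base=rope_theta, traditional=True)

x = mx.zeros((1, 8, 12, qk_rope_head_dim))  # (batch, n_heads, seq_len, head_dim)
print(rope(x).shape)  # rotary embedding applied over the last two axes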
@@ -487,8 +494,12 @@ class Model(nn.Module):
                         ]
                         weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
 
-        # Remove multi-token prediction layer
-        return {k: v for k, v in weights.items() if not k.startswith("model.layers.61")}
+        # Remove multi-token prediction layer and any unused precomputed rotary freqs
+        return {
+            k: v
+            for k, v in weights.items()
+            if not k.startswith("model.layers.61") and "rotary_emb.inv_freq" not in k
+        }
 
     @property
     def layers(self):
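The widened filter in sanitize() now also drops precomputed rotary_emb.inv_freq buffers, which the MLX implementation does not load. A toy illustration with made-up checkpoint keys:

# Toy example with made-up keys: same dict comprehension as the new return.
weights = {
    "model.layers.0.self_attn.q_proj.weight": "tensor",
    "model.layers.0.self_attn.rotary_emb.inv_freq": "tensor",
    "model.layers.61.shared_head.head.weight": "tensor",
}
kept = {
    k: v
    for k, v in weights.items()
    if not k.startswith("model.layers.61") and "rotary_emb.inv_freq" not in k
}
print(list(kept))  # ['model.layers.0.self_attn.q_proj.weight']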
@@ -191,6 +191,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
                         "*.py",
                         "tokenizer.model",
                         "*.tiktoken",
+                        "tiktoken.model",
                         "*.txt",
                         "*.jsonl",
                     ],
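The allow_patterns list restricts the Hugging Face snapshot download to the files the loader actually needs; the new "tiktoken.model" entry covers tokenizers stored under that exact filename rather than a *.tiktoken suffix. A hedged example (the repo id is a placeholder, and the list is trimmed to the patterns visible in this hunk):

# Placeholder repo id; snapshot_download only fetches files matching the patterns.
from pathlib import Path
from huggingface_hub import snapshot_download

model_path = Path(
    snapshot_download(
        "some-org/some-model",
        allow_patterns=[
            "*.py",
            "tokenizer.model",
            "*.tiktoken",
            "tiktoken.model",
            "*.txt",
            "*.jsonl",
        ],
    )
)
print(model_path)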