Merge branch 'main' into adding-orpo-training

This commit is contained in:
Gökdeniz Gülmez 2025-02-28 22:10:21 +01:00 committed by GitHub
commit 6a3912be7f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 171 additions and 45 deletions

View File

@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
__version__ = "0.21.5" __version__ = "0.21.6"

View File

@ -65,12 +65,25 @@ def main():
tokenizer_config={"trust_remote_code": True}, tokenizer_config={"trust_remote_code": True},
) )
print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.") def print_help():
print("The command list:")
print("- 'q' to exit")
print("- 'r' to reset the chat")
print("- 'h' to display these commands")
print(f"[INFO] Starting chat session with {args.model}.")
print_help()
prompt_cache = make_prompt_cache(model, args.max_kv_size) prompt_cache = make_prompt_cache(model, args.max_kv_size)
while True: while True:
query = input(">> ") query = input(">> ")
if query == "q": if query == "q":
break break
if query == "r":
prompt_cache = make_prompt_cache(model, args.max_kv_size)
continue
if query == "h":
print_help()
continue
messages = [{"role": "user", "content": query}] messages = [{"role": "user", "content": query}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
for response in stream_generate( for response in stream_generate(

View File

@ -1,8 +1,27 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
import argparse import argparse
from enum import Enum
from .utils import convert from .utils import convert, mixed_2_6, mixed_3_6
class MixedQuants(Enum):
mixed_3_6 = "mixed_3_6"
mixed_2_6 = "mixed_2_6"
@classmethod
def recipe_names(cls):
return [member.name for member in cls]
def quant_args(arg):
try:
return MixedQuants[arg].value
except KeyError:
raise argparse.ArgumentTypeError(
f"Invalid q-recipe {arg!r}. Choose from: {MixedQuants.recipe_names()}"
)
def configure_parser() -> argparse.ArgumentParser: def configure_parser() -> argparse.ArgumentParser:
@ -29,6 +48,12 @@ def configure_parser() -> argparse.ArgumentParser:
parser.add_argument( parser.add_argument(
"--q-bits", help="Bits per weight for quantization.", type=int, default=4 "--q-bits", help="Bits per weight for quantization.", type=int, default=4
) )
parser.add_argument(
"--quant-predicate",
help=f"Mixed-bit quantization recipe. Choices: {MixedQuants.recipe_names()}",
type=quant_args,
required=False,
)
parser.add_argument( parser.add_argument(
"--dtype", "--dtype",
help="Type to save the non-quantized parameters.", help="Type to save the non-quantized parameters.",

View File

@ -289,17 +289,15 @@ class MLXLM(LM):
contexts, options = zip(*[req.args for req in requests]) contexts, options = zip(*[req.args for req in requests])
# contrary to the doc the second element of the tuple contains # contrary to the doc the second element of the tuple contains
# {'do_sample': False, 'until': ['\n\n'], 'temperature': 0} # {'do_sample': False, 'until': ['\n\n'], 'temperature': 0}
keys = list(options[0].keys())
assert "until" in keys
untils = [x["until"] for x in options]
completions = [] completions = []
for context, until in tqdm(zip(contexts, untils), total=len(contexts)): for context, opt in tqdm(zip(contexts, options), total=len(contexts)):
until = opt["until"]
context = self.tokenizer.encode( context = self.tokenizer.encode(
context, add_special_tokens=not self.use_chat_template context, add_special_tokens=not self.use_chat_template
) )
max_tokens = min( max_tokens = min(
self._max_tokens, opt.get("max_gen_tokens", self._max_tokens),
self.tokenizer.model_max_length - len(context), self.tokenizer.model_max_length - len(context),
) )
text = "" text = ""
@ -334,9 +332,9 @@ def main():
) )
parser.add_argument( parser.add_argument(
"--limit", "--limit",
default=1.0, default=100,
help="Limit the number of examples per task.", help="Limit the number of examples per task.",
type=float, type=int,
) )
parser.add_argument("--seed", type=int, default=123, help="Random seed.") parser.add_argument("--seed", type=int, default=123, help="Random seed.")
parser.add_argument( parser.add_argument(

View File

@ -60,6 +60,11 @@ def setup_arg_parser():
default=DEFAULT_PROMPT, default=DEFAULT_PROMPT,
help="Message to be processed by the model ('-' reads from stdin)", help="Message to be processed by the model ('-' reads from stdin)",
) )
parser.add_argument(
"--prefill-response",
default=None,
help="Prefill response to be used for the chat template",
)
parser.add_argument( parser.add_argument(
"--max-tokens", "--max-tokens",
"-m", "-m",
@ -219,10 +224,14 @@ def main():
messages = [] messages = []
messages.append({"role": "user", "content": prompt}) messages.append({"role": "user", "content": prompt})
has_prefill = args.prefill_response is not None
if has_prefill:
messages.append({"role": "assistant", "content": args.prefill_response})
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(
messages, messages,
tokenize=False, tokenize=False,
add_generation_prompt=True, continue_final_message=has_prefill,
add_generation_prompt=not has_prefill,
**template_kwargs, **template_kwargs,
) )
@ -233,7 +242,8 @@ def main():
test_prompt = tokenizer.apply_chat_template( test_prompt = tokenizer.apply_chat_template(
messages, messages,
tokenize=False, tokenize=False,
add_generation_prompt=True, continue_final_message=has_prefill,
add_generation_prompt=not has_prefill,
) )
prompt = prompt[test_prompt.index("<query>") :] prompt = prompt[test_prompt.index("<query>") :]
prompt = tokenizer.encode(prompt, add_special_tokens=False) prompt = tokenizer.encode(prompt, add_special_tokens=False)

View File

@ -64,6 +64,9 @@ CONFIG_DEFAULTS = {
"grad_checkpoint": False, "grad_checkpoint": False,
"lr_schedule": None, "lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}, "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
"mask_prompt": False,
# ORPO args
"beta": 0.1, "beta": 0.1,
"dpo_loss_type": "sigmoid", "dpo_loss_type": "sigmoid",
"delta": 50.0, "delta": 50.0,
@ -106,7 +109,7 @@ def build_parser():
"--mask-prompt", "--mask-prompt",
action="store_true", action="store_true",
help="Mask the prompt in the loss when training", help="Mask the prompt in the loss when training",
default=False, default=None,
) )
parser.add_argument( parser.add_argument(

View File

@ -2,7 +2,22 @@ import argparse
from typing import List, Union from typing import List, Union
from huggingface_hub import scan_cache_dir from huggingface_hub import scan_cache_dir
from transformers.commands.user import tabulate
def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
"""
Inspired by:
- stackoverflow.com/a/8356620/593036
- stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
"""
col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
lines = []
lines.append(row_format.format(*headers))
lines.append(row_format.format(*["-" * w for w in col_widths]))
for row in rows:
lines.append(row_format.format(*row))
return "\n".join(lines)
def ask_for_confirmation(message: str) -> bool: def ask_for_confirmation(message: str) -> bool:

View File

@ -181,30 +181,37 @@ class DeepseekV3Attention(nn.Module):
bias=config.attention_bias, bias=config.attention_bias,
) )
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) if self.config.rope_scaling is not None:
scaling_factor = self.config.rope_scaling["factor"] mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
if mscale_all_dim: scaling_factor = self.config.rope_scaling["factor"]
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) if mscale_all_dim:
self.scale = self.scale * mscale * mscale mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
self.scale = self.scale * mscale * mscale
rope_kwargs = { rope_kwargs = {
key: self.config.rope_scaling[key] key: self.config.rope_scaling[key]
for key in [ for key in [
"original_max_position_embeddings", "original_max_position_embeddings",
"beta_fast", "beta_fast",
"beta_slow", "beta_slow",
"mscale", "mscale",
"mscale_all_dim", "mscale_all_dim",
] ]
if key in self.config.rope_scaling if key in self.config.rope_scaling
} }
self.rope = DeepseekV3YarnRotaryEmbedding( self.rope = DeepseekV3YarnRotaryEmbedding(
dim=self.qk_rope_head_dim, dim=self.qk_rope_head_dim,
max_position_embeddings=self.max_position_embeddings, max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
base=self.rope_theta, base=self.rope_theta,
**rope_kwargs, **rope_kwargs,
) )
else:
self.rope = nn.RoPE(
dims=self.qk_rope_head_dim,
base=self.rope_theta,
traditional=True,
)
def __call__( def __call__(
self, self,
@ -487,8 +494,12 @@ class Model(nn.Module):
] ]
weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join) weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
# Remove multi-token prediction layer # Remove multi-token prediction layer and any unused precomputed rotary freqs
return {k: v for k, v in weights.items() if not k.startswith("model.layers.61")} return {
k: v
for k, v in weights.items()
if not k.startswith("model.layers.61") and "rotary_emb.inv_freq" not in k
}
@property @property
def layers(self): def layers(self):

View File

@ -23,8 +23,10 @@ class ModelArgs(BaseModelArgs):
rope_theta: float = 10000 rope_theta: float = 10000
rope_traditional: bool = False rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
partial_rotary_factor: float = 1.0
max_position_embeddings: int = 131072 max_position_embeddings: int = 131072
original_max_position_embeddings: int = 4096 original_max_position_embeddings: int = 4096
tie_word_embeddings: bool = False
def __post_init__(self): def __post_init__(self):
if self.num_key_value_heads is None: if self.num_key_value_heads is None:
@ -59,9 +61,10 @@ class Attention(nn.Module):
self.qkv_proj = nn.Linear(dim, op_size, bias=False) self.qkv_proj = nn.Linear(dim, op_size, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
rope_dim = int(head_dim * args.partial_rotary_factor)
if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]: if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
self.rope = SuScaledRotaryEmbedding( self.rope = SuScaledRotaryEmbedding(
head_dim, rope_dim,
base=args.rope_theta, base=args.rope_theta,
max_position_embeddings=args.max_position_embeddings, max_position_embeddings=args.max_position_embeddings,
original_max_position_embeddings=args.original_max_position_embeddings, original_max_position_embeddings=args.original_max_position_embeddings,
@ -74,7 +77,7 @@ class Attention(nn.Module):
assert isinstance(args.rope_scaling["factor"], float) assert isinstance(args.rope_scaling["factor"], float)
rope_scale = 1 / args.rope_scaling["factor"] rope_scale = 1 / args.rope_scaling["factor"]
self.rope = nn.RoPE( self.rope = nn.RoPE(
head_dim, rope_dim,
traditional=args.rope_traditional, traditional=args.rope_traditional,
base=args.rope_theta, base=args.rope_theta,
scale=rope_scale, scale=rope_scale,
@ -190,7 +193,8 @@ class Model(nn.Module):
super().__init__() super().__init__()
self.model_type = args.model_type self.model_type = args.model_type
self.model = Phi3Model(args) self.model = Phi3Model(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
self.args = args self.args = args
def __call__( def __call__(
@ -200,7 +204,11 @@ class Model(nn.Module):
cache=None, cache=None,
): ):
out = self.model(inputs, mask, cache) out = self.model(inputs, mask, cache)
return self.lm_head(out) if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out
@property @property
def layers(self): def layers(self):

View File

@ -51,11 +51,13 @@ class SuScaledRotaryEmbedding(nn.Module):
+ math.log(max_position_embeddings / original_max_position_embeddings) + math.log(max_position_embeddings / original_max_position_embeddings)
/ math.log(original_max_position_embeddings) / math.log(original_max_position_embeddings)
) )
self.dim = dims
def __call__(self, x, offset: int = 0): def __call__(self, x, offset: int = 0):
x[..., : self.dim] = self.scale * x[..., : self.dim]
return mx.fast.rope( return mx.fast.rope(
self.scale * x, x,
x.shape[-1], self.dim,
traditional=False, traditional=False,
base=None, base=None,
scale=1.0, scale=1.0,

View File

@ -191,6 +191,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
"*.py", "*.py",
"tokenizer.model", "tokenizer.model",
"*.tiktoken", "*.tiktoken",
"tiktoken.model",
"*.txt", "*.txt",
"*.jsonl", "*.jsonl",
], ],
@ -1015,6 +1016,46 @@ def save_config(
json.dump(config, fid, indent=4) json.dump(config, fid, indent=4)
def mixed_quant_predicate_builder(
low_bits: int = 4, high_bits: int = 4, group_size: int = 64
) -> Callable[[str, nn.Module, dict], Union[bool, dict]]:
def mixed_quant_predicate(
path: str,
module: nn.Module,
config: dict,
) -> Union[bool, dict]:
"""Implements mixed quantization predicates with similar choices to, for example, llama.cpp's Q4_K_M.
Ref: https://github.com/ggerganov/llama.cpp/blob/917786f43d0f29b7c77a0c56767c0fa4df68b1c5/src/llama.cpp#L5265
By Alex Barron: https://gist.github.com/barronalex/84addb8078be21969f1690c1454855f3
"""
if not hasattr(module, "to_quantized"):
return False
index = int(path.split(".")[2]) if len(path.split(".")) > 2 else 0
num_layers = config["num_hidden_layers"]
use_more_bits = (
index < num_layers // 8
or index >= 7 * num_layers // 8
or (index - num_layers // 8) % 3 == 2
)
if "v_proj" in path and use_more_bits:
return {"group_size": group_size, "bits": high_bits}
if "down_proj" in path and use_more_bits:
return {"group_size": group_size, "bits": high_bits}
if "lm_head" in path:
return {"group_size": group_size, "bits": high_bits}
return {"group_size": group_size, "bits": low_bits}
return mixed_quant_predicate
mixed_3_6 = mixed_quant_predicate_builder(low_bits=3)
mixed_2_6 = mixed_quant_predicate_builder(low_bits=2)
def convert( def convert(
hf_path: str, hf_path: str,
mlx_path: str = "mlx_model", mlx_path: str = "mlx_model",