Merge branch 'main' into adding-orpo-training

Gökdeniz Gülmez 2025-02-10 10:51:01 +01:00 committed by GitHub
commit 575ece6ef0
16 changed files with 492 additions and 139 deletions

View File

@@ -78,6 +78,14 @@ You can specify the output location with `--adapter-path`.
You can resume fine-tuning with an existing adapter with
`--resume-adapter-file <path_to_adapters.safetensors>`.
#### Prompt Masking
The default training computes a loss for every token in the sample. You can
ignore the prompt and compute loss for just the completion by passing
`--mask-prompt`. Note this is only supported for `chat` and `completion`
datasets. For `chat` datasets the final message in the message list is
considered the completion. See the [dataset section](#Data) for more details.
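
As a rough illustration of what masking does (the helper below is hypothetical, not an mlx_lm API), the tokenized prompt length is recorded as an offset and only completion tokens count toward the loss:

```python
# Sketch only: how a chat sample could become (tokens, offset) so the loss is
# computed on the completion alone. Mirrors the --mask-prompt behavior.
import mlx.core as mx

def build_masked_example(tokenizer, messages):
    # Token ids for the full conversation (prompt messages + final completion).
    tokens = tokenizer.apply_chat_template(messages)
    # Length of everything before the final message; these positions are masked.
    offset = len(tokenizer.apply_chat_template(messages[:-1]))
    return tokens, offset

def completion_mask(offset, length, num_targets):
    # Target position i is trained on only if it lies past the prompt offset
    # and within the sample's true (unpadded) length.
    steps = mx.arange(1, num_targets + 1)
    return mx.logical_and(steps >= offset, steps <= length)
```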
### ORPO Training

Odds Ratio Preference Optimization (ORPO) training fine-tunes models using human preference data. Usage:
@@ -343,11 +351,27 @@ hf_dataset:
- Use `prompt_feature` and `completion_feature` to specify keys for a
`completions` dataset. Use `text_feature` to specify the key for a `text`
dataset. Use `chat_feature` to specify the key for a chat dataset.
- To specify the train, valid, or test splits, set the corresponding
`{train,valid,test}_split` argument.
You can specify a list of Hugging Face datasets with a list of records each
with the same structure as above. For example:
```yaml
hf_dataset:
  - name: "Open-Orca/OpenOrca"
    train_split: "train[:90%]"
    valid_split: "train[-10%:]"
    prompt_feature: "question"
    completion_feature: "response"
  - name: "trl-lib/ultrafeedback_binarized"
    train_split: "train[:90%]"
    valid_split: "train[-10%:]"
    chat_feature: "chosen"
```
- Arguments specified in `config` will be passed as keyword arguments to
[`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
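
For illustration (assuming the Hugging Face `datasets` package; the `revision` key is just an example), the `config` entries are forwarded unchanged:

```python
# Sketch: keyword arguments from `config` go straight to datasets.load_dataset.
from datasets import load_dataset

hf_config = {"revision": "main"}  # any load_dataset keyword argument works here
ds = load_dataset("trl-lib/ultrafeedback_binarized", split="train[:90%]", **hf_config)
```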

View File

@@ -152,7 +152,7 @@ def main():
print("Saving...")
metadata = {}
metadata["model"] = args.model
-metadata["chat_template"] = tokenizer.chat_template
+metadata["chat_template"] = json.dumps(tokenizer.chat_template)
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
save_prompt_cache(args.prompt_cache_file, cache, metadata)

View File

@@ -23,7 +23,6 @@ response = generate(
tokenizer,
prompt=prompt,
verbose=True,
-temp=0.0,
prompt_cache=prompt_cache,
)

View File

@@ -93,6 +93,12 @@ def setup_arg_parser():
action="store_true",
help="Use the default chat template",
)
parser.add_argument(
"--chat-template-config",
help="Additional config for `apply_chat_template`. Should be a dictionary of"
" string keys to values represented as a JSON decodable string.",
default=None,
)
parser.add_argument(
"--verbose",
type=str2bool,
@@ -149,7 +155,6 @@ def setup_arg_parser():
def main():
parser = setup_arg_parser()
args = parser.parse_args()
mx.random.seed(args.seed)
# Load the prompt cache and metadata if a cache file is provided
@@ -195,11 +200,15 @@ def main():
for eos_token in args.extra_eos_token:
tokenizer.add_eos_token(eos_token)
template_kwargs = {}
if args.chat_template_config is not None:
template_kwargs = json.loads(args.chat_template_config)
if args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
elif using_cache:
-tokenizer.chat_template = metadata["chat_template"]
+tokenizer.chat_template = json.loads(metadata["chat_template"])
prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t")
prompt = sys.stdin.read() if prompt == "-" else prompt
@@ -209,8 +218,12 @@ def main():
else:
messages = []
messages.append({"role": "user", "content": prompt})
prompt = tokenizer.apply_chat_template(
-messages, tokenize=False, add_generation_prompt=True
+messages,
+tokenize=False,
+add_generation_prompt=True,
+**template_kwargs,
)
# Treat the prompt as a suffix assuming that the prefix is in the

View File

@@ -101,6 +101,14 @@ def build_parser():
choices=["lora", "dora", "full"],
help="Type of fine-tuning to perform: lora, dora, or full.",
)
parser.add_argument(
"--mask-prompt",
action="store_true",
help="Mask the prompt in the loss when training",
default=False,
)
parser.add_argument(
"--training-mode",
type=str,

View File

@@ -282,12 +282,12 @@ class MoEGate(nn.Module):
if self.topk_method == "group_limited_greedy":
bsz, seq_len = x.shape[:2]
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
-group_scores = scores.max(axis=-1)
+group_scores = scores.max(axis=-1, keepdims=True)
k = self.n_group - self.topk_group
-group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
-batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
-seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
-scores[batch_idx, seq_idx, group_idx] = 0.0
+group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
+scores = mx.put_along_axis(
+scores, group_idx, mx.array(0.0, scores.dtype), axis=-2
+)
scores = scores.reshape(bsz, seq_len, -1)
k = self.top_k

View File

@@ -271,6 +271,38 @@ class DeepseekV3MLP(nn.Module):
return down_proj
@mx.compile
def group_expert_select(
gates,
e_score_correction_bias,
top_k,
n_group,
topk_group,
routed_scaling_factor,
norm_topk_prob,
):
k = top_k
scores = mx.sigmoid(gates.astype(mx.float32))
scores = scores + e_score_correction_bias
scores = mx.unflatten(scores, axis=-1, shape=(n_group, -1))
group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1, keepdims=True)
k = n_group - topk_group
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
scores = mx.put_along_axis(scores, group_idx, mx.array(0.0), axis=-2)
scores = mx.flatten(scores, -2, -1)
k = top_k
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
scores = mx.take_along_axis(scores, inds, axis=-1)
if top_k > 1 and norm_topk_prob:
denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
scores = scores / denominator
scores = scores * routed_scaling_factor
return inds, scores
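
A quick, illustrative way to exercise the routing above (all sizes below are made up, not taken from a real model config):

```python
# Illustrative only: route dummy gate logits through group_expert_select.
import mlx.core as mx

n_experts, n_group, topk_group, top_k = 256, 8, 4, 8
gates = mx.random.normal((2, 16, n_experts))  # (batch, sequence, experts)
bias = mx.zeros((n_experts,))
inds, scores = group_expert_select(gates, bias, top_k, n_group, topk_group, 2.5, True)
print(inds.shape, scores.shape)  # (2, 16, 8) expert indices and routing weights
```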
class MoEGate(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
@@ -279,38 +311,22 @@ class MoEGate(nn.Module):
self.norm_topk_prob = config.norm_topk_prob
self.n_routed_experts = config.n_routed_experts
self.routed_scaling_factor = config.routed_scaling_factor
-self.topk_method = config.topk_method
self.n_group = config.n_group
self.topk_group = config.topk_group
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))
+assert config.topk_method == "noaux_tc", "Unsupported topk method."
def __call__(self, x):
-gates = x @ self.weight.T
-scores = mx.sigmoid(gates.astype(mx.float32))
-assert self.topk_method == "noaux_tc", "Unsupported topk method."
-bsz, seq_len = x.shape[:2]
-scores = scores + self.e_score_correction_bias
-scores = scores.reshape(bsz, seq_len, self.n_group, -1)
-group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1)
-k = self.n_group - self.topk_group
-group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
-batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
-seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
-scores[batch_idx, seq_idx, group_idx] = 0.0
-scores = scores.reshape(bsz, seq_len, -1)
-k = self.top_k
-inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
-scores = mx.take_along_axis(scores, inds, axis=-1)
-if self.top_k > 1 and self.norm_topk_prob:
-denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
-scores = scores / denominator
-scores = scores * self.routed_scaling_factor
-return inds, scores
+return group_expert_select(
+x @ self.weight.T,
+self.e_score_correction_bias,
+self.top_k,
+self.n_group,
+self.topk_group,
+self.routed_scaling_factor,
+self.norm_topk_prob,
+)
class DeepseekV3MoE(nn.Module):

View File

@@ -0,0 +1,195 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .rope_utils import initialize_rope
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
logits_scaling: float
attention_multiplier: float
embedding_multiplier: float
residual_multiplier: float
max_position_embeddings: int
num_key_value_heads: int
attention_bias: bool
mlp_bias: bool
rope_theta: float
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = True
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.head_dim = head_dim = args.hidden_size // n_heads
self.scale = args.attention_multiplier
attention_bias = args.attention_bias
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
self.rope = initialize_rope(
self.head_dim,
args.rope_theta,
False,
args.rope_scaling,
args.max_position_embeddings,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
hidden_dim = args.intermediate_size
if hasattr(args, "mlp_bias"):
mlp_bias = args.mlp_bias
else:
mlp_bias = False
self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.residual_multiplier = args.residual_multiplier
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r * self.residual_multiplier
r = self.mlp(self.post_attention_layernorm(h))
out = h + r * self.residual_multiplier
return out
class GraniteModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.embedding_multiplier = args.embedding_multiplier
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs) * self.embedding_multiplier
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, cache=c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = GraniteModel(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
self.logits_scaling = args.logits_scaling
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out / self.logits_scaling
@property
def layers(self):
return self.model.layers

View File

@@ -76,7 +76,6 @@ class Attention(nn.Module):
head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
if kv_proj:
self.k_proj = nn.Linear(
@@ -107,7 +106,6 @@ class Attention(nn.Module):
B, L, D = x.shape
queries = self.q_proj(x)
if kv_states is None:
keys, values = self.k_proj(x), self.v_proj(x)
kv_states = keys, values
@@ -198,6 +196,9 @@ class DecoderLayer(nn.Module):
super().__init__()
self.hidden_size = args.hidden_size
self.self_attn = Attention(kv_proj, args)
if args.num_experts == 1:
self.mlp = MLP(args.hidden_size, args.intermediate_size)
else:
self.mlp = MoeBlock(args)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
@@ -231,7 +232,10 @@ class HunYuanModel(nn.Module):
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
-DecoderLayer(args=args, kv_proj=(i % args.cla_share_factor) == 0)
+DecoderLayer(
+args=args,
+kv_proj=(not args.use_cla) or (i % args.cla_share_factor) == 0,
+)
for i in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
@@ -251,7 +255,7 @@ class HunYuanModel(nn.Module):
cache = [None] * len(self.layers)
for i, (layer, c) in enumerate(zip(self.layers, cache)):
-if i % self.args.cla_share_factor == 0:
+if (not self.args.use_cla) or i % self.args.cla_share_factor == 0:
shared_kv_states = None
h, shared_kv_states = layer(h, mask, c, shared_kv_states)
@@ -275,6 +279,29 @@ class Model(nn.Module):
return self.model.embed_tokens.as_linear(out)
def sanitize(self, weights):
if "model.layers.0.mlp.gate_and_up_proj.weight" in weights:
new_weights = {}
D = self.args.hidden_size
n_kv_heads = self.args.num_key_value_heads
n_kv_groups = self.args.num_attention_heads // n_kv_heads
head_dim = D // self.args.num_attention_heads
for k, v in weights.items():
if "qkv_proj" in k:
v = v.reshape(n_kv_heads, n_kv_groups + 2, head_dim, -1)
splits = v.split([n_kv_groups, n_kv_groups + 1], axis=1)
for k_up, v_new in zip(["q_proj", "k_proj", "v_proj"], splits):
k_new = k.replace("qkv_proj", k_up)
new_weights[k_new] = mx.flatten(v_new, 0, 2)
elif "gate_and_up_proj" in k:
splits = v.split(2, axis=0)
for k_up, v_new in zip(["up_proj", "gate_proj"], splits):
k_new = k.replace("gate_and_up_proj", k_up)
new_weights[k_new] = v_new
else:
new_weights[k] = v
weights = new_weights
if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
return weights return weights
for l in range(self.args.num_hidden_layers): for l in range(self.args.num_hidden_layers):

View File

@@ -1,5 +1,6 @@
import json
from functools import partial
from typing import List
from transformers import AutoTokenizer
@@ -368,3 +369,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None):
detokenizer_class,
eos_token_ids=eos_token_ids,
)
def no_bos_or_eos(sequence: List, bos: int, eos: int) -> List:
removed_bos = sequence if sequence[0] != bos else sequence[1:]
return removed_bos[:-1] if removed_bos[-1] == eos else removed_bos

View File

@@ -1,6 +1,8 @@
+import itertools
import json
+import types
from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
from transformers import PreTrainedTokenizer
@@ -92,14 +94,24 @@ class ChatDataset:
https://platform.openai.com/docs/guides/fine-tuning/example-format
"""
-def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
-self._data = [
-tokenizer.apply_chat_template(
-d["messages"],
-tools=d.get("tools", None),
-)
-for d in data
-]
+def __init__(
+self,
+data: List[Dict[str, str]],
+tokenizer: PreTrainedTokenizer,
+chat_key: str = "messages",
+mask_prompt: bool = False,
+):
+self._data = []
+for d in data:
+messages = d[chat_key]
+tools = d.get("tools", None)
+tokens = tokenizer.apply_chat_template(messages, tools=tools)
+if mask_prompt:
+messages = messages[:-1]
+offset = len(tokenizer.apply_chat_template(messages, tools=tools))
+self._data.append((tokens, offset))
+else:
+self._data.append(tokens)
def __getitem__(self, idx: int):
return self._data[idx]
@@ -121,16 +133,25 @@ class CompletionsDataset:
tokenizer: PreTrainedTokenizer,
prompt_key: str,
completion_key: str,
+mask_prompt: bool,
):
-self._data = [
-tokenizer.apply_chat_template(
+self._data = []
+for d in data:
+tokens = tokenizer.apply_chat_template(
[
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[completion_key]},
],
)
-for d in data
-]
+if mask_prompt:
+offset = len(
+tokenizer.apply_chat_template(
+[{"role": "user", "content": d[prompt_key]}]
+)
+)
+self._data.append((tokens, offset))
+else:
+self._data.append(tokens)
def __getitem__(self, idx: int):
return self._data[idx]
@@ -139,52 +160,60 @@ class CompletionsDataset:
return len(self._data)
class ConcatenatedDataset:
def __init__(self, data: List[Any]):
self._data = list(itertools.chain(*data))
def __getitem__(self, idx: int):
return self._data[idx]
def __len__(self):
return len(self._data)
def create_dataset(
args,
data,
tokenizer: PreTrainedTokenizer,
-prompt_feature: Optional[str] = None,
-completion_feature: Optional[str] = None,
+config,
):
-prompt_feature = prompt_feature or "prompt"
-completion_feature = completion_feature or "completion"
+mask_prompt = getattr(config, "mask_prompt", False)
+prompt_feature = getattr(config, "prompt_feature", "prompt")
+text_feature = getattr(config, "text_feature", "text")
+completion_feature = getattr(config, "completion_feature", "completion")
+chat_feature = getattr(config, "chat_feature", "messages")
sample = data[0]
if args.training_mode == "normal":
-if "messages" in sample:
-return ChatDataset(data, tokenizer)
+if chat_feature in sample:
+return ChatDataset(data, tokenizer, chat_key=chat_feature, mask_prompt=mask_prompt)
elif prompt_feature in sample and completion_feature in sample:
-return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
-elif "text" in sample:
-return Dataset(data, tokenizer)
+return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature, mask_prompt)
+elif text_feature in sample:
+if mask_prompt:
+raise ValueError("Prompt masking not supported for text dataset.")
+return Dataset(data, tokenizer, text_key=text_feature)
else:
raise ValueError(
"Unsupported data format, check the supported formats here:\n"
"https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
)
-elif args.training_mode == "orpo":
+else:
if "chosen" in sample and "rejected" in sample:
return ORPODataset(data, tokenizer)
else:
raise ValueError(
"Unsupported training mode, check the supported training modes and their formats here:\n"
"https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#training-modes."
)
def load_local_dataset(
args,
data_path: Path,
tokenizer: PreTrainedTokenizer,
-prompt_feature: Optional[str] = None,
-completion_feature: Optional[str] = None,
+config,
):
def load_subset(path):
if not path.exists():
return []
with open(path, "r") as fid:
data = [json.loads(l) for l in fid]
-return create_dataset(args, data, tokenizer, prompt_feature, completion_feature)
+return create_dataset(data, tokenizer, config)
names = ("train", "valid", "test")
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
@@ -195,8 +224,7 @@ def load_hf_dataset(
args,
data_id: str,
tokenizer: PreTrainedTokenizer,
-prompt_feature: Optional[str] = None,
-completion_feature: Optional[str] = None,
+config,
):
from datasets import exceptions, load_dataset
@@ -207,9 +235,7 @@ def load_hf_dataset(
train, valid, test = [
(
-create_dataset(
-args, dataset[n], tokenizer, prompt_feature, completion_feature
-)
+create_dataset(args, dataset[n], tokenizer, config)
if n in dataset.keys()
else []
)
@@ -225,42 +251,61 @@ def load_hf_dataset(
def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
import datasets
-hf_args = args.hf_dataset
-dataset_name = hf_args["name"]
-print(f"Loading Hugging Face dataset {dataset_name}.")
-text_feature = hf_args.get("text_feature")
-prompt_feature = hf_args.get("prompt_feature")
-completion_feature = hf_args.get("completion_feature")
-def create_hf_dataset(split: str = None):
+def create_hf_dataset(dataset_name, config, split, hf_config):
ds = datasets.load_dataset(
dataset_name,
split=split,
-**hf_args.get("config", {}),
+**hf_config,
)
-if prompt_feature and completion_feature:
-return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
-elif text_feature:
-return Dataset(ds, tokenizer, text_key=text_feature)
-else:
-raise ValueError(
-"Specify either a prompt and completion feature or a text "
-"feature for the Hugging Face dataset."
-)
+return create_dataset(ds, tokenizer, config)
+dataset_collection = args.hf_dataset
+if isinstance(dataset_collection, dict):
+dataset_collection = [dataset_collection]
+collection = []
+for ds in dataset_collection:
+ds_name = ds["name"]
+print(f"Loading Hugging Face dataset {ds_name}.")
+ds["mask_prompt"] = getattr(args, "mask_prompt", False)
+config = types.SimpleNamespace(**ds)
+hf_config = ds.get("config", {})
if args.train:
-train_split = hf_args.get("train_split", "train[:80%]")
-valid_split = hf_args.get("valid_split", "train[-10%:]")
-train = create_hf_dataset(split=train_split)
-valid = create_hf_dataset(split=valid_split)
+train_split = ds.get("train_split", "train[:80%]")
+valid_split = ds.get("valid_split", "train[-10%:]")
+train = create_hf_dataset(
+ds_name,
+config,
+train_split,
+hf_config,
+)
+valid = create_hf_dataset(
+ds_name,
+config,
+valid_split,
+hf_config,
+)
else:
train, valid = [], []
if args.test:
-test = create_hf_dataset(split=hf_args.get("test_split"))
+test_split = ds.get("test_split")
+test = create_hf_dataset(
+ds_name,
+config,
+test_split,
+hf_config,
+)
else:
test = []
-return train, valid, test
+collection.append((train, valid, test))
+if len(collection) == 1:
+return collection[0]
+# Otherwise concatenate them
+return tuple(map(ConcatenatedDataset, zip(*collection)))
def load_dataset(args, tokenizer: PreTrainedTokenizer):
@@ -268,18 +313,11 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
train, valid, test = load_custom_hf_dataset(args, tokenizer)
else:
data_path = Path(args.data)
-prompt_feature = getattr(args, "prompt_feature", None)
-completion_feature = getattr(args, "completion_feature", None)
if data_path.exists():
-train, valid, test = load_local_dataset(
-args, data_path, tokenizer, prompt_feature, completion_feature
-)
+train, valid, test = load_local_dataset(args, data_path, tokenizer, args)
else:
print(f"Loading Hugging Face dataset {args.data}.")
-train, valid, test = load_hf_dataset(
-args, args.data, tokenizer, prompt_feature, completion_feature
-)
+train, valid, test = load_hf_dataset(args.data, tokenizer, args)
if args.train and len(train) == 0:
raise ValueError(

View File

@@ -5,13 +5,16 @@ import shutil
import time
from dataclasses import dataclass, field
from pathlib import Path
-from typing import Union
+from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten
from transformers import PreTrainedTokenizer
from .datasets import CompletionsDataset
def grad_checkpoint(layer):
@@ -63,20 +66,30 @@ class TrainingArgs:
)
-def default_loss(model, inputs, targets, lengths):
+def default_loss(model, batch, lengths):
+inputs = batch[:, :-1]
+targets = batch[:, 1:]
logits = model(inputs)
logits = logits.astype(mx.float32)
-length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
+steps = mx.arange(1, targets.shape[1] + 1)
+mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
-ce = nn.losses.cross_entropy(logits, targets) * length_mask
-ntoks = length_mask.sum()
+ce = nn.losses.cross_entropy(logits, targets) * mask
+ntoks = mask.sum()
ce = ce.sum() / ntoks
return ce, ntoks
-def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
+def iterate_batches(
+dataset,
+tokenizer,
+batch_size,
+max_seq_length,
+train=False,
+):
# Sort by length:
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
if len(dataset) < batch_size:
@@ -101,6 +114,10 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
indices = np.random.permutation(len(batch_idx))
for i in indices:
batch = [dataset[j] for j in batch_idx[i]]
if len(batch[0]) == 2:
batch, offsets = zip(*batch)
else:
offsets = [0] * len(batch)
lengths = [len(x) for x in batch]
if max(lengths) > max_seq_length:
print(
@@ -123,8 +140,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
truncated_length # Update lengths to match truncated lengths
)
batch = mx.array(batch_arr)
-yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
+yield batch, mx.array(list(zip(offsets, lengths)))
if not train:
break

View File

@@ -94,6 +94,7 @@ def linear_to_lora_layers(
"phimoe",
"gemma",
"gemma2",
"granite",
"helium",
"starcoder2",
"cohere",

View File

@@ -78,14 +78,15 @@ class TestDatasets(unittest.TestCase):
self.assertTrue(isinstance(train, datasets.ChatDataset))
def test_hf(self):
-args = types.SimpleNamespace(
-hf_dataset={
+hf_args = {
"name": "billsum",
"prompt_feature": "text",
"completion_feature": "summary",
"train_split": "train[:2%]",
"valid_split": "train[-2%:]",
-},
+}
+args = types.SimpleNamespace(
+hf_dataset=hf_args,
test=False,
train=True,
)
@@ -97,6 +98,16 @@ class TestDatasets(unittest.TestCase):
self.assertTrue(len(valid[0]) > 0)
self.assertEqual(len(test), 0)
args = types.SimpleNamespace(
hf_dataset=[hf_args, hf_args],
test=False,
train=True,
)
train_double, valid_double, test_double = datasets.load_dataset(args, tokenizer)
self.assertEqual(2 * len(train), len(train_double))
self.assertEqual(2 * len(valid), len(valid_double))
self.assertEqual(2 * len(test), len(test_double))
if __name__ == "__main__":
unittest.main()

View File

@@ -134,9 +134,7 @@ def find_alignment(
logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :])
# consider only the logits associated with predicting text
sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot]
-token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype(
-sampled_logits.dtype
-)
+token_probs = mx.softmax(sampled_logits, precise=True, axis=-1)
text_token_probs = mx.take_along_axis(
token_probs, mx.array(text_tokens)[:, None], axis=1
).squeeze(1)
@@ -144,10 +142,11 @@ def find_alignment(
# heads * tokens * frames
weights = mx.stack(
-[cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads]
+[cross_qk[_l][0, _h] for _l, _h in model.alignment_heads.tolist()]
)
weights = weights[:, :, : num_frames // 2]
-weights = mx.softmax(weights * qk_scale, axis=-1)
+weights = mx.softmax(weights * qk_scale, axis=-1, precise=True)
+weights = weights.astype(mx.float32)
mean = mx.mean(weights, axis=-2, keepdims=True)
std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt()
weights = (weights - mean) / std

View File

@@ -84,7 +84,7 @@ class MultiHeadAttention(nn.Module):
w = mx.softmax(qk, axis=-1, precise=True)
out = (w @ v).transpose(0, 2, 1, 3)
out = out.reshape(n_batch, n_ctx, n_state)
-return out, qk.astype(mx.float32)
+return out, qk
class ResidualAttentionBlock(nn.Module):