Merge branch 'main' into adding-orpo-training

This commit is contained in:
Gökdeniz Gülmez 2025-02-10 10:51:01 +01:00 committed by GitHub
commit 575ece6ef0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 492 additions and 139 deletions

View File

@ -78,6 +78,14 @@ You can specify the output location with `--adapter-path`.
You can resume fine-tuning with an existing adapter with
`--resume-adapter-file <path_to_adapters.safetensors>`.
#### Prompt Masking
The default training computes a loss for every token in the sample. You can
ignore the prompt and compute loss for just the completion by passing
`--mask-prompt`. Note this is only supported for `chat` and `completion`
datasets. For `chat` datasets the final message in the message list is
considered the completion. See the [dataset section](#Data) for more details.
### ORPO Training
Odds Ratio Preference Optimization (ORPO) training fine-tunes models using human preference data. Usage:
@ -343,11 +351,27 @@ hf_dataset:
- Use `prompt_feature` and `completion_feature` to specify keys for a
`completions` dataset. Use `text_feature` to specify the key for a `text`
dataset.
dataset. Use `chat_feature` to specify the key for a chat dataset.
- To specify the train, valid, or test splits, set the corresponding
`{train,valid,test}_split` argument.
You can specify a list of Hugging Face datasets with a list of records each
with the same structure as above. For example:
```yaml
hf_dataset:
- name: "Open-Orca/OpenOrca"
train_split: "train[:90%]"
valid_split: "train[-10%:]"
prompt_feature: "question"
completion_feature: "response"
- name: "trl-lib/ultrafeedback_binarized"
train_split: "train[:90%]"
valid_split: "train[-10%:]"
chat_feature: "chosen"
```
- Arguments specified in `config` will be passed as keyword arguments to
[`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).

View File

@ -152,7 +152,7 @@ def main():
print("Saving...")
metadata = {}
metadata["model"] = args.model
metadata["chat_template"] = tokenizer.chat_template
metadata["chat_template"] = json.dumps(tokenizer.chat_template)
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
save_prompt_cache(args.prompt_cache_file, cache, metadata)

View File

@ -23,7 +23,6 @@ response = generate(
tokenizer,
prompt=prompt,
verbose=True,
temp=0.0,
prompt_cache=prompt_cache,
)

View File

@ -93,6 +93,12 @@ def setup_arg_parser():
action="store_true",
help="Use the default chat template",
)
parser.add_argument(
"--chat-template-config",
help="Additional config for `apply_chat_template`. Should be a dictionary of"
" string keys to values represented as a JSON decodable string.",
default=None,
)
parser.add_argument(
"--verbose",
type=str2bool,
@ -149,7 +155,6 @@ def setup_arg_parser():
def main():
parser = setup_arg_parser()
args = parser.parse_args()
mx.random.seed(args.seed)
# Load the prompt cache and metadata if a cache file is provided
@ -195,11 +200,15 @@ def main():
for eos_token in args.extra_eos_token:
tokenizer.add_eos_token(eos_token)
template_kwargs = {}
if args.chat_template_config is not None:
template_kwargs = json.loads(args.chat_template_config)
if args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
elif using_cache:
tokenizer.chat_template = metadata["chat_template"]
tokenizer.chat_template = json.loads(metadata["chat_template"])
prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t")
prompt = sys.stdin.read() if prompt == "-" else prompt
@ -209,8 +218,12 @@ def main():
else:
messages = []
messages.append({"role": "user", "content": prompt})
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
messages,
tokenize=False,
add_generation_prompt=True,
**template_kwargs,
)
# Treat the prompt as a suffix assuming that the prefix is in the

View File

@ -101,6 +101,14 @@ def build_parser():
choices=["lora", "dora", "full"],
help="Type of fine-tuning to perform: lora, dora, or full.",
)
parser.add_argument(
"--mask-prompt",
action="store_true",
help="Mask the prompt in the loss when training",
default=False,
)
parser.add_argument(
"--training-mode",
type=str,

View File

@ -282,12 +282,12 @@ class MoEGate(nn.Module):
if self.topk_method == "group_limited_greedy":
bsz, seq_len = x.shape[:2]
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
group_scores = scores.max(axis=-1)
group_scores = scores.max(axis=-1, keepdims=True)
k = self.n_group - self.topk_group
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
scores[batch_idx, seq_idx, group_idx] = 0.0
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
scores = mx.put_along_axis(
scores, group_idx, mx.array(0.0, scores.dtype), axis=-2
)
scores = scores.reshape(bsz, seq_len, -1)
k = self.top_k

View File

@ -271,6 +271,38 @@ class DeepseekV3MLP(nn.Module):
return down_proj
@mx.compile
def group_expert_select(
gates,
e_score_correction_bias,
top_k,
n_group,
topk_group,
routed_scaling_factor,
norm_topk_prob,
):
k = top_k
scores = mx.sigmoid(gates.astype(mx.float32))
scores = scores + e_score_correction_bias
scores = mx.unflatten(scores, axis=-1, shape=(n_group, -1))
group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1, keepdims=True)
k = n_group - topk_group
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
scores = mx.put_along_axis(scores, group_idx, mx.array(0.0), axis=-2)
scores = mx.flatten(scores, -2, -1)
k = top_k
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
scores = mx.take_along_axis(scores, inds, axis=-1)
if top_k > 1 and norm_topk_prob:
denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
scores = scores / denominator
scores = scores * routed_scaling_factor
return inds, scores
class MoEGate(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
@ -279,38 +311,22 @@ class MoEGate(nn.Module):
self.norm_topk_prob = config.norm_topk_prob
self.n_routed_experts = config.n_routed_experts
self.routed_scaling_factor = config.routed_scaling_factor
self.topk_method = config.topk_method
self.n_group = config.n_group
self.topk_group = config.topk_group
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))
assert config.topk_method == "noaux_tc", "Unsupported topk method."
def __call__(self, x):
gates = x @ self.weight.T
scores = mx.sigmoid(gates.astype(mx.float32))
assert self.topk_method == "noaux_tc", "Unsupported topk method."
bsz, seq_len = x.shape[:2]
scores = scores + self.e_score_correction_bias
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1)
k = self.n_group - self.topk_group
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
scores[batch_idx, seq_idx, group_idx] = 0.0
scores = scores.reshape(bsz, seq_len, -1)
k = self.top_k
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
scores = mx.take_along_axis(scores, inds, axis=-1)
if self.top_k > 1 and self.norm_topk_prob:
denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
scores = scores / denominator
scores = scores * self.routed_scaling_factor
return inds, scores
return group_expert_select(
x @ self.weight.T,
self.e_score_correction_bias,
self.top_k,
self.n_group,
self.topk_group,
self.routed_scaling_factor,
self.norm_topk_prob,
)
class DeepseekV3MoE(nn.Module):

View File

@ -0,0 +1,195 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .rope_utils import initialize_rope
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
logits_scaling: float
attention_multiplier: float
embedding_multiplier: float
residual_multiplier: float
max_position_embeddings: int
num_key_value_heads: int
attention_bias: bool
mlp_bias: bool
rope_theta: float
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = True
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.head_dim = head_dim = args.hidden_size // n_heads
self.scale = args.attention_multiplier
attention_bias = args.attention_bias
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
self.rope = initialize_rope(
self.head_dim,
args.rope_theta,
False,
args.rope_scaling,
args.max_position_embeddings,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
hidden_dim = args.intermediate_size
if hasattr(args, "mlp_bias"):
mlp_bias = args.mlp_bias
else:
mlp_bias = False
self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.residual_multiplier = args.residual_multiplier
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r * self.residual_multiplier
r = self.mlp(self.post_attention_layernorm(h))
out = h + r * self.residual_multiplier
return out
class GraniteModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.embedding_multiplier = args.embedding_multiplier
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs) * self.embedding_multiplier
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, cache=c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = GraniteModel(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
self.logits_scaling = args.logits_scaling
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out / self.logits_scaling
@property
def layers(self):
return self.model.layers

View File

@ -76,7 +76,6 @@ class Attention(nn.Module):
head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
if kv_proj:
self.k_proj = nn.Linear(
@ -107,7 +106,6 @@ class Attention(nn.Module):
B, L, D = x.shape
queries = self.q_proj(x)
if kv_states is None:
keys, values = self.k_proj(x), self.v_proj(x)
kv_states = keys, values
@ -198,7 +196,10 @@ class DecoderLayer(nn.Module):
super().__init__()
self.hidden_size = args.hidden_size
self.self_attn = Attention(kv_proj, args)
self.mlp = MoeBlock(args)
if args.num_experts == 1:
self.mlp = MLP(args.hidden_size, args.intermediate_size)
else:
self.mlp = MoeBlock(args)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
@ -231,7 +232,10 @@ class HunYuanModel(nn.Module):
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
DecoderLayer(args=args, kv_proj=(i % args.cla_share_factor) == 0)
DecoderLayer(
args=args,
kv_proj=(not args.use_cla) or (i % args.cla_share_factor) == 0,
)
for i in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
@ -251,7 +255,7 @@ class HunYuanModel(nn.Module):
cache = [None] * len(self.layers)
for i, (layer, c) in enumerate(zip(self.layers, cache)):
if i % self.args.cla_share_factor == 0:
if (not self.args.use_cla) or i % self.args.cla_share_factor == 0:
shared_kv_states = None
h, shared_kv_states = layer(h, mask, c, shared_kv_states)
@ -275,6 +279,29 @@ class Model(nn.Module):
return self.model.embed_tokens.as_linear(out)
def sanitize(self, weights):
if "model.layers.0.mlp.gate_and_up_proj.weight" in weights:
new_weights = {}
D = self.args.hidden_size
n_kv_heads = self.args.num_key_value_heads
n_kv_groups = self.args.num_attention_heads // n_kv_heads
head_dim = D // self.args.num_attention_heads
for k, v in weights.items():
if "qkv_proj" in k:
v = v.reshape(n_kv_heads, n_kv_groups + 2, head_dim, -1)
splits = v.split([n_kv_groups, n_kv_groups + 1], axis=1)
for k_up, v_new in zip(["q_proj", "k_proj", "v_proj"], splits):
k_new = k.replace("qkv_proj", k_up)
new_weights[k_new] = mx.flatten(v_new, 0, 2)
elif "gate_and_up_proj" in k:
splits = v.split(2, axis=0)
for k_up, v_new in zip(["up_proj", "gate_proj"], splits):
k_new = k.replace("gate_and_up_proj", k_up)
new_weights[k_new] = v_new
else:
new_weights[k] = v
weights = new_weights
if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
return weights
for l in range(self.args.num_hidden_layers):

View File

@ -1,5 +1,6 @@
import json
from functools import partial
from typing import List
from transformers import AutoTokenizer
@ -368,3 +369,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None):
detokenizer_class,
eos_token_ids=eos_token_ids,
)
def no_bos_or_eos(sequence: List, bos: int, eos: int) -> List:
removed_bos = sequence if sequence[0] != bos else sequence[1:]
return removed_bos[:-1] if removed_bos[-1] == eos else removed_bos

View File

@ -1,6 +1,8 @@
import itertools
import json
import types
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union
from transformers import PreTrainedTokenizer
@ -92,14 +94,24 @@ class ChatDataset:
https://platform.openai.com/docs/guides/fine-tuning/example-format
"""
def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
self._data = [
tokenizer.apply_chat_template(
d["messages"],
tools=d.get("tools", None),
)
for d in data
]
def __init__(
self,
data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
chat_key: str = "messages",
mask_prompt: bool = False,
):
self._data = []
for d in data:
messages = d[chat_key]
tools = d.get("tools", None)
tokens = tokenizer.apply_chat_template(messages, tools=tools)
if mask_prompt:
messages = messages[:-1]
offset = len(tokenizer.apply_chat_template(messages, tools=tools))
self._data.append((tokens, offset))
else:
self._data.append(tokens)
def __getitem__(self, idx: int):
return self._data[idx]
@ -121,16 +133,25 @@ class CompletionsDataset:
tokenizer: PreTrainedTokenizer,
prompt_key: str,
completion_key: str,
mask_prompt: bool,
):
self._data = [
tokenizer.apply_chat_template(
self._data = []
for d in data:
tokens = tokenizer.apply_chat_template(
[
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[completion_key]},
],
)
for d in data
]
if mask_prompt:
offset = len(
tokenizer.apply_chat_template(
[{"role": "user", "content": d[prompt_key]}]
)
)
self._data.append((tokens, offset))
else:
self._data.append(tokens)
def __getitem__(self, idx: int):
return self._data[idx]
@ -139,52 +160,60 @@ class CompletionsDataset:
return len(self._data)
class ConcatenatedDataset:
def __init__(self, data: List[Any]):
self._data = list(itertools.chain(*data))
def __getitem__(self, idx: int):
return self._data[idx]
def __len__(self):
return len(self._data)
def create_dataset(
args,
data,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
completion_feature: Optional[str] = None,
config,
):
prompt_feature = prompt_feature or "prompt"
completion_feature = completion_feature or "completion"
mask_prompt = getattr(config, "mask_prompt", False)
prompt_feature = getattr(config, "prompt_feature", "prompt")
text_feature = getattr(config, "text_feature", "text")
completion_feature = getattr(config, "completion_feature", "completion")
chat_feature = getattr(config, "chat_feature", "messages")
sample = data[0]
if args.training_mode == "normal":
if "messages" in sample:
return ChatDataset(data, tokenizer)
if chat_feature in sample:
return ChatDataset(data, tokenizer, chat_key=chat_feature, mask_prompt=mask_prompt)
elif prompt_feature in sample and completion_feature in sample:
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
elif "text" in sample:
return Dataset(data, tokenizer)
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature, mask_prompt)
elif text_feature in sample:
if mask_prompt:
raise ValueError("Prompt masking not supported for text dataset.")
return Dataset(data, tokenizer, text_key=text_feature)
else:
raise ValueError(
"Unsupported data format, check the supported formats here:\n"
"https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
)
elif args.training_mode == "orpo":
else:
if "chosen" in sample and "rejected" in sample:
return ORPODataset(data, tokenizer)
else:
raise ValueError(
"Unsupported training mode, check the supported training modes and their formats here:\n"
"https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#training-modes."
)
def load_local_dataset(
args,
data_path: Path,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
completion_feature: Optional[str] = None,
config,
):
def load_subset(path):
if not path.exists():
return []
with open(path, "r") as fid:
data = [json.loads(l) for l in fid]
return create_dataset(args, data, tokenizer, prompt_feature, completion_feature)
return create_dataset(data, tokenizer, config)
names = ("train", "valid", "test")
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
@ -195,8 +224,7 @@ def load_hf_dataset(
args,
data_id: str,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
completion_feature: Optional[str] = None,
config,
):
from datasets import exceptions, load_dataset
@ -207,9 +235,7 @@ def load_hf_dataset(
train, valid, test = [
(
create_dataset(
args, dataset[n], tokenizer, prompt_feature, completion_feature
)
create_dataset(args, dataset[n], tokenizer, config)
if n in dataset.keys()
else []
)
@ -225,42 +251,61 @@ def load_hf_dataset(
def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
import datasets
hf_args = args.hf_dataset
dataset_name = hf_args["name"]
print(f"Loading Hugging Face dataset {dataset_name}.")
text_feature = hf_args.get("text_feature")
prompt_feature = hf_args.get("prompt_feature")
completion_feature = hf_args.get("completion_feature")
def create_hf_dataset(split: str = None):
def create_hf_dataset(dataset_name, config, split, hf_config):
ds = datasets.load_dataset(
dataset_name,
split=split,
**hf_args.get("config", {}),
**hf_config,
)
if prompt_feature and completion_feature:
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
elif text_feature:
return Dataset(ds, tokenizer, text_key=text_feature)
else:
raise ValueError(
"Specify either a prompt and completion feature or a text "
"feature for the Hugging Face dataset."
return create_dataset(ds, tokenizer, config)
dataset_collection = args.hf_dataset
if isinstance(dataset_collection, dict):
dataset_collection = [dataset_collection]
collection = []
for ds in dataset_collection:
ds_name = ds["name"]
print(f"Loading Hugging Face dataset {ds_name}.")
ds["mask_prompt"] = getattr(args, "mask_prompt", False)
config = types.SimpleNamespace(**ds)
hf_config = ds.get("config", {})
if args.train:
train_split = ds.get("train_split", "train[:80%]")
valid_split = ds.get("valid_split", "train[-10%:]")
train = create_hf_dataset(
ds_name,
config,
train_split,
hf_config,
)
valid = create_hf_dataset(
ds_name,
config,
valid_split,
hf_config,
)
else:
train, valid = [], []
if args.train:
train_split = hf_args.get("train_split", "train[:80%]")
valid_split = hf_args.get("valid_split", "train[-10%:]")
train = create_hf_dataset(split=train_split)
valid = create_hf_dataset(split=valid_split)
else:
train, valid = [], []
if args.test:
test = create_hf_dataset(split=hf_args.get("test_split"))
else:
test = []
if args.test:
test_split = ds.get("test_split")
test = create_hf_dataset(
ds_name,
config,
test_split,
hf_config,
)
else:
test = []
return train, valid, test
collection.append((train, valid, test))
if len(collection) == 1:
return collection[0]
# Otherwise concatenate them
return tuple(map(ConcatenatedDataset, zip(*collection)))
def load_dataset(args, tokenizer: PreTrainedTokenizer):
@ -268,18 +313,11 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
train, valid, test = load_custom_hf_dataset(args, tokenizer)
else:
data_path = Path(args.data)
prompt_feature = getattr(args, "prompt_feature", None)
completion_feature = getattr(args, "completion_feature", None)
if data_path.exists():
train, valid, test = load_local_dataset(
args, data_path, tokenizer, prompt_feature, completion_feature
)
train, valid, test = load_local_dataset(args, data_path, tokenizer, args)
else:
print(f"Loading Hugging Face dataset {args.data}.")
train, valid, test = load_hf_dataset(
args, args.data, tokenizer, prompt_feature, completion_feature
)
train, valid, test = load_hf_dataset(args.data, tokenizer, args)
if args.train and len(train) == 0:
raise ValueError(

View File

@ -5,13 +5,16 @@ import shutil
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Union
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten
from transformers import PreTrainedTokenizer
from .datasets import CompletionsDataset
def grad_checkpoint(layer):
@ -63,20 +66,30 @@ class TrainingArgs:
)
def default_loss(model, inputs, targets, lengths):
def default_loss(model, batch, lengths):
inputs = batch[:, :-1]
targets = batch[:, 1:]
logits = model(inputs)
logits = logits.astype(mx.float32)
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
steps = mx.arange(1, targets.shape[1] + 1)
mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
ce = nn.losses.cross_entropy(logits, targets) * length_mask
ntoks = length_mask.sum()
ce = nn.losses.cross_entropy(logits, targets) * mask
ntoks = mask.sum()
ce = ce.sum() / ntoks
return ce, ntoks
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
def iterate_batches(
dataset,
tokenizer,
batch_size,
max_seq_length,
train=False,
):
# Sort by length:
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
if len(dataset) < batch_size:
@ -101,6 +114,10 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
indices = np.random.permutation(len(batch_idx))
for i in indices:
batch = [dataset[j] for j in batch_idx[i]]
if len(batch[0]) == 2:
batch, offsets = zip(*batch)
else:
offsets = [0] * len(batch)
lengths = [len(x) for x in batch]
if max(lengths) > max_seq_length:
print(
@ -123,8 +140,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
truncated_length # Update lengths to match truncated lengths
)
batch = mx.array(batch_arr)
yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
yield batch, mx.array(list(zip(offsets, lengths)))
if not train:
break

View File

@ -94,6 +94,7 @@ def linear_to_lora_layers(
"phimoe",
"gemma",
"gemma2",
"granite",
"helium",
"starcoder2",
"cohere",

View File

@ -78,14 +78,15 @@ class TestDatasets(unittest.TestCase):
self.assertTrue(isinstance(train, datasets.ChatDataset))
def test_hf(self):
hf_args = {
"name": "billsum",
"prompt_feature": "text",
"completion_feature": "summary",
"train_split": "train[:2%]",
"valid_split": "train[-2%:]",
}
args = types.SimpleNamespace(
hf_dataset={
"name": "billsum",
"prompt_feature": "text",
"completion_feature": "summary",
"train_split": "train[:2%]",
"valid_split": "train[-2%:]",
},
hf_dataset=hf_args,
test=False,
train=True,
)
@ -97,6 +98,16 @@ class TestDatasets(unittest.TestCase):
self.assertTrue(len(valid[0]) > 0)
self.assertEqual(len(test), 0)
args = types.SimpleNamespace(
hf_dataset=[hf_args, hf_args],
test=False,
train=True,
)
train_double, valid_double, test_double = datasets.load_dataset(args, tokenizer)
self.assertEqual(2 * len(train), len(train_double))
self.assertEqual(2 * len(valid), len(valid_double))
self.assertEqual(2 * len(test), len(test_double))
if __name__ == "__main__":
unittest.main()

View File

@ -134,9 +134,7 @@ def find_alignment(
logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :])
# consider only the logits associated with predicting text
sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot]
token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype(
sampled_logits.dtype
)
token_probs = mx.softmax(sampled_logits, precise=True, axis=-1)
text_token_probs = mx.take_along_axis(
token_probs, mx.array(text_tokens)[:, None], axis=1
).squeeze(1)
@ -144,10 +142,11 @@ def find_alignment(
# heads * tokens * frames
weights = mx.stack(
[cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads]
[cross_qk[_l][0, _h] for _l, _h in model.alignment_heads.tolist()]
)
weights = weights[:, :, : num_frames // 2]
weights = mx.softmax(weights * qk_scale, axis=-1)
weights = mx.softmax(weights * qk_scale, axis=-1, precise=True)
weights = weights.astype(mx.float32)
mean = mx.mean(weights, axis=-2, keepdims=True)
std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt()
weights = (weights - mean) / std

View File

@ -84,7 +84,7 @@ class MultiHeadAttention(nn.Module):
w = mx.softmax(qk, axis=-1, precise=True)
out = (w @ v).transpose(0, 2, 1, 3)
out = out.reshape(n_batch, n_ctx, n_state)
return out, qk.astype(mx.float32)
return out, qk
class ResidualAttentionBlock(nn.Module):