mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
Merge branch 'main' into adding-orpo-training
This commit is contained in:
commit
575ece6ef0
@ -78,6 +78,14 @@ You can specify the output location with `--adapter-path`.
|
||||
You can resume fine-tuning with an existing adapter with
|
||||
`--resume-adapter-file <path_to_adapters.safetensors>`.
|
||||
|
||||
#### Prompt Masking
|
||||
|
||||
The default training computes a loss for every token in the sample. You can
|
||||
ignore the prompt and compute loss for just the completion by passing
|
||||
`--mask-prompt`. Note this is only supported for `chat` and `completion`
|
||||
datasets. For `chat` datasets the final message in the message list is
|
||||
considered the completion. See the [dataset section](#Data) for more details.
|
||||
|
||||
### ORPO Training
|
||||
|
||||
Odds Ratio Preference Optimization (ORPO) training fine-tunes models using human preference data. Usage:
|
||||
@ -343,11 +351,27 @@ hf_dataset:
|
||||
|
||||
- Use `prompt_feature` and `completion_feature` to specify keys for a
|
||||
`completions` dataset. Use `text_feature` to specify the key for a `text`
|
||||
dataset.
|
||||
dataset. Use `chat_feature` to specify the key for a chat dataset.
|
||||
|
||||
- To specify the train, valid, or test splits, set the corresponding
|
||||
`{train,valid,test}_split` argument.
|
||||
|
||||
You can specify a list of Hugging Face datasets with a list of records each
|
||||
with the same structure as above. For example:
|
||||
|
||||
```yaml
|
||||
hf_dataset:
|
||||
- name: "Open-Orca/OpenOrca"
|
||||
train_split: "train[:90%]"
|
||||
valid_split: "train[-10%:]"
|
||||
prompt_feature: "question"
|
||||
completion_feature: "response"
|
||||
- name: "trl-lib/ultrafeedback_binarized"
|
||||
train_split: "train[:90%]"
|
||||
valid_split: "train[-10%:]"
|
||||
chat_feature: "chosen"
|
||||
```
|
||||
|
||||
- Arguments specified in `config` will be passed as keyword arguments to
|
||||
[`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
|
||||
|
||||
|
@ -152,7 +152,7 @@ def main():
|
||||
print("Saving...")
|
||||
metadata = {}
|
||||
metadata["model"] = args.model
|
||||
metadata["chat_template"] = tokenizer.chat_template
|
||||
metadata["chat_template"] = json.dumps(tokenizer.chat_template)
|
||||
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
|
||||
save_prompt_cache(args.prompt_cache_file, cache, metadata)
|
||||
|
||||
|
@ -23,7 +23,6 @@ response = generate(
|
||||
tokenizer,
|
||||
prompt=prompt,
|
||||
verbose=True,
|
||||
temp=0.0,
|
||||
prompt_cache=prompt_cache,
|
||||
)
|
||||
|
||||
|
@ -93,6 +93,12 @@ def setup_arg_parser():
|
||||
action="store_true",
|
||||
help="Use the default chat template",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chat-template-config",
|
||||
help="Additional config for `apply_chat_template`. Should be a dictionary of"
|
||||
" string keys to values represented as a JSON decodable string.",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
type=str2bool,
|
||||
@ -149,7 +155,6 @@ def setup_arg_parser():
|
||||
def main():
|
||||
parser = setup_arg_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
# Load the prompt cache and metadata if a cache file is provided
|
||||
@ -195,11 +200,15 @@ def main():
|
||||
for eos_token in args.extra_eos_token:
|
||||
tokenizer.add_eos_token(eos_token)
|
||||
|
||||
template_kwargs = {}
|
||||
if args.chat_template_config is not None:
|
||||
template_kwargs = json.loads(args.chat_template_config)
|
||||
|
||||
if args.use_default_chat_template:
|
||||
if tokenizer.chat_template is None:
|
||||
tokenizer.chat_template = tokenizer.default_chat_template
|
||||
elif using_cache:
|
||||
tokenizer.chat_template = metadata["chat_template"]
|
||||
tokenizer.chat_template = json.loads(metadata["chat_template"])
|
||||
|
||||
prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t")
|
||||
prompt = sys.stdin.read() if prompt == "-" else prompt
|
||||
@ -209,8 +218,12 @@ def main():
|
||||
else:
|
||||
messages = []
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
**template_kwargs,
|
||||
)
|
||||
|
||||
# Treat the prompt as a suffix assuming that the prefix is in the
|
||||
|
@ -101,6 +101,14 @@ def build_parser():
|
||||
choices=["lora", "dora", "full"],
|
||||
help="Type of fine-tuning to perform: lora, dora, or full.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mask-prompt",
|
||||
action="store_true",
|
||||
help="Mask the prompt in the loss when training",
|
||||
default=False,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--training-mode",
|
||||
type=str,
|
||||
|
@ -282,12 +282,12 @@ class MoEGate(nn.Module):
|
||||
if self.topk_method == "group_limited_greedy":
|
||||
bsz, seq_len = x.shape[:2]
|
||||
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
|
||||
group_scores = scores.max(axis=-1)
|
||||
group_scores = scores.max(axis=-1, keepdims=True)
|
||||
k = self.n_group - self.topk_group
|
||||
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
|
||||
batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
|
||||
seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
|
||||
scores[batch_idx, seq_idx, group_idx] = 0.0
|
||||
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
|
||||
scores = mx.put_along_axis(
|
||||
scores, group_idx, mx.array(0.0, scores.dtype), axis=-2
|
||||
)
|
||||
scores = scores.reshape(bsz, seq_len, -1)
|
||||
|
||||
k = self.top_k
|
||||
|
@ -271,6 +271,38 @@ class DeepseekV3MLP(nn.Module):
|
||||
return down_proj
|
||||
|
||||
|
||||
@mx.compile
|
||||
def group_expert_select(
|
||||
gates,
|
||||
e_score_correction_bias,
|
||||
top_k,
|
||||
n_group,
|
||||
topk_group,
|
||||
routed_scaling_factor,
|
||||
norm_topk_prob,
|
||||
):
|
||||
|
||||
k = top_k
|
||||
scores = mx.sigmoid(gates.astype(mx.float32))
|
||||
scores = scores + e_score_correction_bias
|
||||
scores = mx.unflatten(scores, axis=-1, shape=(n_group, -1))
|
||||
group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1, keepdims=True)
|
||||
k = n_group - topk_group
|
||||
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
|
||||
scores = mx.put_along_axis(scores, group_idx, mx.array(0.0), axis=-2)
|
||||
scores = mx.flatten(scores, -2, -1)
|
||||
|
||||
k = top_k
|
||||
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
|
||||
scores = mx.take_along_axis(scores, inds, axis=-1)
|
||||
if top_k > 1 and norm_topk_prob:
|
||||
denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
|
||||
scores = scores / denominator
|
||||
scores = scores * routed_scaling_factor
|
||||
|
||||
return inds, scores
|
||||
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
@ -279,38 +311,22 @@ class MoEGate(nn.Module):
|
||||
self.norm_topk_prob = config.norm_topk_prob
|
||||
self.n_routed_experts = config.n_routed_experts
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.topk_method = config.topk_method
|
||||
self.n_group = config.n_group
|
||||
self.topk_group = config.topk_group
|
||||
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
|
||||
self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))
|
||||
assert config.topk_method == "noaux_tc", "Unsupported topk method."
|
||||
|
||||
def __call__(self, x):
|
||||
gates = x @ self.weight.T
|
||||
|
||||
scores = mx.sigmoid(gates.astype(mx.float32))
|
||||
|
||||
assert self.topk_method == "noaux_tc", "Unsupported topk method."
|
||||
bsz, seq_len = x.shape[:2]
|
||||
scores = scores + self.e_score_correction_bias
|
||||
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
|
||||
group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1)
|
||||
k = self.n_group - self.topk_group
|
||||
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
|
||||
batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
|
||||
seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
|
||||
scores[batch_idx, seq_idx, group_idx] = 0.0
|
||||
scores = scores.reshape(bsz, seq_len, -1)
|
||||
|
||||
k = self.top_k
|
||||
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
|
||||
scores = mx.take_along_axis(scores, inds, axis=-1)
|
||||
if self.top_k > 1 and self.norm_topk_prob:
|
||||
denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
|
||||
scores = scores / denominator
|
||||
scores = scores * self.routed_scaling_factor
|
||||
|
||||
return inds, scores
|
||||
return group_expert_select(
|
||||
x @ self.weight.T,
|
||||
self.e_score_correction_bias,
|
||||
self.top_k,
|
||||
self.n_group,
|
||||
self.topk_group,
|
||||
self.routed_scaling_factor,
|
||||
self.norm_topk_prob,
|
||||
)
|
||||
|
||||
|
||||
class DeepseekV3MoE(nn.Module):
|
||||
|
195
llms/mlx_lm/models/granite.py
Normal file
195
llms/mlx_lm/models/granite.py
Normal file
@ -0,0 +1,195 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .rope_utils import initialize_rope
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
logits_scaling: float
|
||||
attention_multiplier: float
|
||||
embedding_multiplier: float
|
||||
residual_multiplier: float
|
||||
max_position_embeddings: int
|
||||
num_key_value_heads: int
|
||||
attention_bias: bool
|
||||
mlp_bias: bool
|
||||
rope_theta: float
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
tie_word_embeddings: bool = True
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
self.head_dim = head_dim = args.hidden_size // n_heads
|
||||
|
||||
self.scale = args.attention_multiplier
|
||||
attention_bias = args.attention_bias
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
|
||||
|
||||
self.rope = initialize_rope(
|
||||
self.head_dim,
|
||||
args.rope_theta,
|
||||
False,
|
||||
args.rope_scaling,
|
||||
args.max_position_embeddings,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
hidden_dim = args.intermediate_size
|
||||
if hasattr(args, "mlp_bias"):
|
||||
mlp_bias = args.mlp_bias
|
||||
else:
|
||||
mlp_bias = False
|
||||
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.residual_multiplier = args.residual_multiplier
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r * self.residual_multiplier
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r * self.residual_multiplier
|
||||
return out
|
||||
|
||||
|
||||
class GraniteModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.embedding_multiplier = args.embedding_multiplier
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs) * self.embedding_multiplier
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, cache=c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = GraniteModel(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
self.logits_scaling = args.logits_scaling
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out / self.logits_scaling
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
@ -76,7 +76,6 @@ class Attention(nn.Module):
|
||||
|
||||
head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
|
||||
if kv_proj:
|
||||
self.k_proj = nn.Linear(
|
||||
@ -107,7 +106,6 @@ class Attention(nn.Module):
|
||||
B, L, D = x.shape
|
||||
|
||||
queries = self.q_proj(x)
|
||||
|
||||
if kv_states is None:
|
||||
keys, values = self.k_proj(x), self.v_proj(x)
|
||||
kv_states = keys, values
|
||||
@ -198,7 +196,10 @@ class DecoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(kv_proj, args)
|
||||
self.mlp = MoeBlock(args)
|
||||
if args.num_experts == 1:
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
else:
|
||||
self.mlp = MoeBlock(args)
|
||||
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
@ -231,7 +232,10 @@ class HunYuanModel(nn.Module):
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
DecoderLayer(args=args, kv_proj=(i % args.cla_share_factor) == 0)
|
||||
DecoderLayer(
|
||||
args=args,
|
||||
kv_proj=(not args.use_cla) or (i % args.cla_share_factor) == 0,
|
||||
)
|
||||
for i in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
@ -251,7 +255,7 @@ class HunYuanModel(nn.Module):
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for i, (layer, c) in enumerate(zip(self.layers, cache)):
|
||||
if i % self.args.cla_share_factor == 0:
|
||||
if (not self.args.use_cla) or i % self.args.cla_share_factor == 0:
|
||||
shared_kv_states = None
|
||||
h, shared_kv_states = layer(h, mask, c, shared_kv_states)
|
||||
|
||||
@ -275,6 +279,29 @@ class Model(nn.Module):
|
||||
return self.model.embed_tokens.as_linear(out)
|
||||
|
||||
def sanitize(self, weights):
|
||||
|
||||
if "model.layers.0.mlp.gate_and_up_proj.weight" in weights:
|
||||
new_weights = {}
|
||||
D = self.args.hidden_size
|
||||
n_kv_heads = self.args.num_key_value_heads
|
||||
n_kv_groups = self.args.num_attention_heads // n_kv_heads
|
||||
head_dim = D // self.args.num_attention_heads
|
||||
for k, v in weights.items():
|
||||
if "qkv_proj" in k:
|
||||
v = v.reshape(n_kv_heads, n_kv_groups + 2, head_dim, -1)
|
||||
splits = v.split([n_kv_groups, n_kv_groups + 1], axis=1)
|
||||
for k_up, v_new in zip(["q_proj", "k_proj", "v_proj"], splits):
|
||||
k_new = k.replace("qkv_proj", k_up)
|
||||
new_weights[k_new] = mx.flatten(v_new, 0, 2)
|
||||
elif "gate_and_up_proj" in k:
|
||||
splits = v.split(2, axis=0)
|
||||
for k_up, v_new in zip(["up_proj", "gate_proj"], splits):
|
||||
k_new = k.replace("gate_and_up_proj", k_up)
|
||||
new_weights[k_new] = v_new
|
||||
else:
|
||||
new_weights[k] = v
|
||||
weights = new_weights
|
||||
|
||||
if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
|
||||
return weights
|
||||
for l in range(self.args.num_hidden_layers):
|
||||
|
@ -1,5 +1,6 @@
|
||||
import json
|
||||
from functools import partial
|
||||
from typing import List
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
@ -368,3 +369,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None):
|
||||
detokenizer_class,
|
||||
eos_token_ids=eos_token_ids,
|
||||
)
|
||||
|
||||
|
||||
def no_bos_or_eos(sequence: List, bos: int, eos: int) -> List:
|
||||
removed_bos = sequence if sequence[0] != bos else sequence[1:]
|
||||
return removed_bos[:-1] if removed_bos[-1] == eos else removed_bos
|
||||
|
@ -1,6 +1,8 @@
|
||||
import itertools
|
||||
import json
|
||||
import types
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
@ -92,14 +94,24 @@ class ChatDataset:
|
||||
https://platform.openai.com/docs/guides/fine-tuning/example-format
|
||||
"""
|
||||
|
||||
def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
|
||||
self._data = [
|
||||
tokenizer.apply_chat_template(
|
||||
d["messages"],
|
||||
tools=d.get("tools", None),
|
||||
)
|
||||
for d in data
|
||||
]
|
||||
def __init__(
|
||||
self,
|
||||
data: List[Dict[str, str]],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
chat_key: str = "messages",
|
||||
mask_prompt: bool = False,
|
||||
):
|
||||
self._data = []
|
||||
for d in data:
|
||||
messages = d[chat_key]
|
||||
tools = d.get("tools", None)
|
||||
tokens = tokenizer.apply_chat_template(messages, tools=tools)
|
||||
if mask_prompt:
|
||||
messages = messages[:-1]
|
||||
offset = len(tokenizer.apply_chat_template(messages, tools=tools))
|
||||
self._data.append((tokens, offset))
|
||||
else:
|
||||
self._data.append(tokens)
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
return self._data[idx]
|
||||
@ -121,16 +133,25 @@ class CompletionsDataset:
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
prompt_key: str,
|
||||
completion_key: str,
|
||||
mask_prompt: bool,
|
||||
):
|
||||
self._data = [
|
||||
tokenizer.apply_chat_template(
|
||||
self._data = []
|
||||
for d in data:
|
||||
tokens = tokenizer.apply_chat_template(
|
||||
[
|
||||
{"role": "user", "content": d[prompt_key]},
|
||||
{"role": "assistant", "content": d[completion_key]},
|
||||
],
|
||||
)
|
||||
for d in data
|
||||
]
|
||||
if mask_prompt:
|
||||
offset = len(
|
||||
tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": d[prompt_key]}]
|
||||
)
|
||||
)
|
||||
self._data.append((tokens, offset))
|
||||
else:
|
||||
self._data.append(tokens)
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
return self._data[idx]
|
||||
@ -139,52 +160,60 @@ class CompletionsDataset:
|
||||
return len(self._data)
|
||||
|
||||
|
||||
class ConcatenatedDataset:
|
||||
def __init__(self, data: List[Any]):
|
||||
self._data = list(itertools.chain(*data))
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
return self._data[idx]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._data)
|
||||
|
||||
def create_dataset(
|
||||
args,
|
||||
data,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
prompt_feature: Optional[str] = None,
|
||||
completion_feature: Optional[str] = None,
|
||||
config,
|
||||
):
|
||||
prompt_feature = prompt_feature or "prompt"
|
||||
completion_feature = completion_feature or "completion"
|
||||
mask_prompt = getattr(config, "mask_prompt", False)
|
||||
prompt_feature = getattr(config, "prompt_feature", "prompt")
|
||||
text_feature = getattr(config, "text_feature", "text")
|
||||
completion_feature = getattr(config, "completion_feature", "completion")
|
||||
chat_feature = getattr(config, "chat_feature", "messages")
|
||||
sample = data[0]
|
||||
|
||||
|
||||
if args.training_mode == "normal":
|
||||
if "messages" in sample:
|
||||
return ChatDataset(data, tokenizer)
|
||||
if chat_feature in sample:
|
||||
return ChatDataset(data, tokenizer, chat_key=chat_feature, mask_prompt=mask_prompt)
|
||||
elif prompt_feature in sample and completion_feature in sample:
|
||||
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
|
||||
elif "text" in sample:
|
||||
return Dataset(data, tokenizer)
|
||||
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature, mask_prompt)
|
||||
elif text_feature in sample:
|
||||
if mask_prompt:
|
||||
raise ValueError("Prompt masking not supported for text dataset.")
|
||||
return Dataset(data, tokenizer, text_key=text_feature)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported data format, check the supported formats here:\n"
|
||||
"https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
|
||||
)
|
||||
elif args.training_mode == "orpo":
|
||||
else:
|
||||
if "chosen" in sample and "rejected" in sample:
|
||||
return ORPODataset(data, tokenizer)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported training mode, check the supported training modes and their formats here:\n"
|
||||
"https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#training-modes."
|
||||
)
|
||||
|
||||
|
||||
def load_local_dataset(
|
||||
args,
|
||||
data_path: Path,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
prompt_feature: Optional[str] = None,
|
||||
completion_feature: Optional[str] = None,
|
||||
config,
|
||||
):
|
||||
def load_subset(path):
|
||||
if not path.exists():
|
||||
return []
|
||||
with open(path, "r") as fid:
|
||||
data = [json.loads(l) for l in fid]
|
||||
return create_dataset(args, data, tokenizer, prompt_feature, completion_feature)
|
||||
|
||||
return create_dataset(data, tokenizer, config)
|
||||
|
||||
names = ("train", "valid", "test")
|
||||
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
|
||||
@ -195,8 +224,7 @@ def load_hf_dataset(
|
||||
args,
|
||||
data_id: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
prompt_feature: Optional[str] = None,
|
||||
completion_feature: Optional[str] = None,
|
||||
config,
|
||||
):
|
||||
from datasets import exceptions, load_dataset
|
||||
|
||||
@ -207,9 +235,7 @@ def load_hf_dataset(
|
||||
|
||||
train, valid, test = [
|
||||
(
|
||||
create_dataset(
|
||||
args, dataset[n], tokenizer, prompt_feature, completion_feature
|
||||
)
|
||||
create_dataset(args, dataset[n], tokenizer, config)
|
||||
if n in dataset.keys()
|
||||
else []
|
||||
)
|
||||
@ -225,42 +251,61 @@ def load_hf_dataset(
|
||||
def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
|
||||
import datasets
|
||||
|
||||
hf_args = args.hf_dataset
|
||||
dataset_name = hf_args["name"]
|
||||
print(f"Loading Hugging Face dataset {dataset_name}.")
|
||||
text_feature = hf_args.get("text_feature")
|
||||
prompt_feature = hf_args.get("prompt_feature")
|
||||
completion_feature = hf_args.get("completion_feature")
|
||||
|
||||
def create_hf_dataset(split: str = None):
|
||||
def create_hf_dataset(dataset_name, config, split, hf_config):
|
||||
ds = datasets.load_dataset(
|
||||
dataset_name,
|
||||
split=split,
|
||||
**hf_args.get("config", {}),
|
||||
**hf_config,
|
||||
)
|
||||
if prompt_feature and completion_feature:
|
||||
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
|
||||
elif text_feature:
|
||||
return Dataset(ds, tokenizer, text_key=text_feature)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Specify either a prompt and completion feature or a text "
|
||||
"feature for the Hugging Face dataset."
|
||||
return create_dataset(ds, tokenizer, config)
|
||||
|
||||
dataset_collection = args.hf_dataset
|
||||
if isinstance(dataset_collection, dict):
|
||||
dataset_collection = [dataset_collection]
|
||||
|
||||
collection = []
|
||||
for ds in dataset_collection:
|
||||
ds_name = ds["name"]
|
||||
print(f"Loading Hugging Face dataset {ds_name}.")
|
||||
ds["mask_prompt"] = getattr(args, "mask_prompt", False)
|
||||
config = types.SimpleNamespace(**ds)
|
||||
hf_config = ds.get("config", {})
|
||||
if args.train:
|
||||
train_split = ds.get("train_split", "train[:80%]")
|
||||
valid_split = ds.get("valid_split", "train[-10%:]")
|
||||
train = create_hf_dataset(
|
||||
ds_name,
|
||||
config,
|
||||
train_split,
|
||||
hf_config,
|
||||
)
|
||||
valid = create_hf_dataset(
|
||||
ds_name,
|
||||
config,
|
||||
valid_split,
|
||||
hf_config,
|
||||
)
|
||||
else:
|
||||
train, valid = [], []
|
||||
|
||||
if args.train:
|
||||
train_split = hf_args.get("train_split", "train[:80%]")
|
||||
valid_split = hf_args.get("valid_split", "train[-10%:]")
|
||||
train = create_hf_dataset(split=train_split)
|
||||
valid = create_hf_dataset(split=valid_split)
|
||||
else:
|
||||
train, valid = [], []
|
||||
if args.test:
|
||||
test = create_hf_dataset(split=hf_args.get("test_split"))
|
||||
else:
|
||||
test = []
|
||||
if args.test:
|
||||
test_split = ds.get("test_split")
|
||||
test = create_hf_dataset(
|
||||
ds_name,
|
||||
config,
|
||||
test_split,
|
||||
hf_config,
|
||||
)
|
||||
else:
|
||||
test = []
|
||||
|
||||
return train, valid, test
|
||||
collection.append((train, valid, test))
|
||||
|
||||
if len(collection) == 1:
|
||||
return collection[0]
|
||||
|
||||
# Otherwise concatenate them
|
||||
return tuple(map(ConcatenatedDataset, zip(*collection)))
|
||||
|
||||
|
||||
def load_dataset(args, tokenizer: PreTrainedTokenizer):
|
||||
@ -268,18 +313,11 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
|
||||
train, valid, test = load_custom_hf_dataset(args, tokenizer)
|
||||
else:
|
||||
data_path = Path(args.data)
|
||||
|
||||
prompt_feature = getattr(args, "prompt_feature", None)
|
||||
completion_feature = getattr(args, "completion_feature", None)
|
||||
if data_path.exists():
|
||||
train, valid, test = load_local_dataset(
|
||||
args, data_path, tokenizer, prompt_feature, completion_feature
|
||||
)
|
||||
train, valid, test = load_local_dataset(args, data_path, tokenizer, args)
|
||||
else:
|
||||
print(f"Loading Hugging Face dataset {args.data}.")
|
||||
train, valid, test = load_hf_dataset(
|
||||
args, args.data, tokenizer, prompt_feature, completion_feature
|
||||
)
|
||||
train, valid, test = load_hf_dataset(args.data, tokenizer, args)
|
||||
|
||||
if args.train and len(train) == 0:
|
||||
raise ValueError(
|
||||
|
@ -5,13 +5,16 @@ import shutil
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from mlx.nn.utils import average_gradients
|
||||
from mlx.utils import tree_flatten
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from .datasets import CompletionsDataset
|
||||
|
||||
|
||||
def grad_checkpoint(layer):
|
||||
@ -63,20 +66,30 @@ class TrainingArgs:
|
||||
)
|
||||
|
||||
|
||||
def default_loss(model, inputs, targets, lengths):
|
||||
def default_loss(model, batch, lengths):
|
||||
inputs = batch[:, :-1]
|
||||
targets = batch[:, 1:]
|
||||
|
||||
logits = model(inputs)
|
||||
logits = logits.astype(mx.float32)
|
||||
|
||||
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
|
||||
steps = mx.arange(1, targets.shape[1] + 1)
|
||||
mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
|
||||
|
||||
ce = nn.losses.cross_entropy(logits, targets) * length_mask
|
||||
ntoks = length_mask.sum()
|
||||
ce = nn.losses.cross_entropy(logits, targets) * mask
|
||||
ntoks = mask.sum()
|
||||
ce = ce.sum() / ntoks
|
||||
|
||||
return ce, ntoks
|
||||
|
||||
|
||||
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
|
||||
def iterate_batches(
|
||||
dataset,
|
||||
tokenizer,
|
||||
batch_size,
|
||||
max_seq_length,
|
||||
train=False,
|
||||
):
|
||||
# Sort by length:
|
||||
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
|
||||
if len(dataset) < batch_size:
|
||||
@ -101,6 +114,10 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
|
||||
indices = np.random.permutation(len(batch_idx))
|
||||
for i in indices:
|
||||
batch = [dataset[j] for j in batch_idx[i]]
|
||||
if len(batch[0]) == 2:
|
||||
batch, offsets = zip(*batch)
|
||||
else:
|
||||
offsets = [0] * len(batch)
|
||||
lengths = [len(x) for x in batch]
|
||||
if max(lengths) > max_seq_length:
|
||||
print(
|
||||
@ -123,8 +140,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
|
||||
truncated_length # Update lengths to match truncated lengths
|
||||
)
|
||||
batch = mx.array(batch_arr)
|
||||
|
||||
yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
|
||||
yield batch, mx.array(list(zip(offsets, lengths)))
|
||||
|
||||
if not train:
|
||||
break
|
||||
|
@ -94,6 +94,7 @@ def linear_to_lora_layers(
|
||||
"phimoe",
|
||||
"gemma",
|
||||
"gemma2",
|
||||
"granite",
|
||||
"helium",
|
||||
"starcoder2",
|
||||
"cohere",
|
||||
|
@ -78,14 +78,15 @@ class TestDatasets(unittest.TestCase):
|
||||
self.assertTrue(isinstance(train, datasets.ChatDataset))
|
||||
|
||||
def test_hf(self):
|
||||
hf_args = {
|
||||
"name": "billsum",
|
||||
"prompt_feature": "text",
|
||||
"completion_feature": "summary",
|
||||
"train_split": "train[:2%]",
|
||||
"valid_split": "train[-2%:]",
|
||||
}
|
||||
args = types.SimpleNamespace(
|
||||
hf_dataset={
|
||||
"name": "billsum",
|
||||
"prompt_feature": "text",
|
||||
"completion_feature": "summary",
|
||||
"train_split": "train[:2%]",
|
||||
"valid_split": "train[-2%:]",
|
||||
},
|
||||
hf_dataset=hf_args,
|
||||
test=False,
|
||||
train=True,
|
||||
)
|
||||
@ -97,6 +98,16 @@ class TestDatasets(unittest.TestCase):
|
||||
self.assertTrue(len(valid[0]) > 0)
|
||||
self.assertEqual(len(test), 0)
|
||||
|
||||
args = types.SimpleNamespace(
|
||||
hf_dataset=[hf_args, hf_args],
|
||||
test=False,
|
||||
train=True,
|
||||
)
|
||||
train_double, valid_double, test_double = datasets.load_dataset(args, tokenizer)
|
||||
self.assertEqual(2 * len(train), len(train_double))
|
||||
self.assertEqual(2 * len(valid), len(valid_double))
|
||||
self.assertEqual(2 * len(test), len(test_double))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -134,9 +134,7 @@ def find_alignment(
|
||||
logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :])
|
||||
# consider only the logits associated with predicting text
|
||||
sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot]
|
||||
token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype(
|
||||
sampled_logits.dtype
|
||||
)
|
||||
token_probs = mx.softmax(sampled_logits, precise=True, axis=-1)
|
||||
text_token_probs = mx.take_along_axis(
|
||||
token_probs, mx.array(text_tokens)[:, None], axis=1
|
||||
).squeeze(1)
|
||||
@ -144,10 +142,11 @@ def find_alignment(
|
||||
|
||||
# heads * tokens * frames
|
||||
weights = mx.stack(
|
||||
[cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads]
|
||||
[cross_qk[_l][0, _h] for _l, _h in model.alignment_heads.tolist()]
|
||||
)
|
||||
weights = weights[:, :, : num_frames // 2]
|
||||
weights = mx.softmax(weights * qk_scale, axis=-1)
|
||||
weights = mx.softmax(weights * qk_scale, axis=-1, precise=True)
|
||||
weights = weights.astype(mx.float32)
|
||||
mean = mx.mean(weights, axis=-2, keepdims=True)
|
||||
std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt()
|
||||
weights = (weights - mean) / std
|
||||
|
@ -84,7 +84,7 @@ class MultiHeadAttention(nn.Module):
|
||||
w = mx.softmax(qk, axis=-1, precise=True)
|
||||
out = (w @ v).transpose(0, 2, 1, 3)
|
||||
out = out.reshape(n_batch, n_ctx, n_state)
|
||||
return out, qk.astype(mx.float32)
|
||||
return out, qk
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
|
Loading…
Reference in New Issue
Block a user