mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 18:11:17 +08:00
Merge branch 'main' into adding-orpo-training
This commit is contained in:
commit
575ece6ef0
@ -78,6 +78,14 @@ You can specify the output location with `--adapter-path`.
|
|||||||
You can resume fine-tuning with an existing adapter with
|
You can resume fine-tuning with an existing adapter with
|
||||||
`--resume-adapter-file <path_to_adapters.safetensors>`.
|
`--resume-adapter-file <path_to_adapters.safetensors>`.
|
||||||
|
|
||||||
|
#### Prompt Masking
|
||||||
|
|
||||||
|
The default training computes a loss for every token in the sample. You can
|
||||||
|
ignore the prompt and compute loss for just the completion by passing
|
||||||
|
`--mask-prompt`. Note this is only supported for `chat` and `completion`
|
||||||
|
datasets. For `chat` datasets the final message in the message list is
|
||||||
|
considered the completion. See the [dataset section](#Data) for more details.
|
||||||
|
|
||||||
### ORPO Training
|
### ORPO Training
|
||||||
|
|
||||||
Odds Ratio Preference Optimization (ORPO) training fine-tunes models using human preference data. Usage:
|
Odds Ratio Preference Optimization (ORPO) training fine-tunes models using human preference data. Usage:
|
||||||
@ -343,11 +351,27 @@ hf_dataset:
|
|||||||
|
|
||||||
- Use `prompt_feature` and `completion_feature` to specify keys for a
|
- Use `prompt_feature` and `completion_feature` to specify keys for a
|
||||||
`completions` dataset. Use `text_feature` to specify the key for a `text`
|
`completions` dataset. Use `text_feature` to specify the key for a `text`
|
||||||
dataset.
|
dataset. Use `chat_feature` to specify the key for a chat dataset.
|
||||||
|
|
||||||
- To specify the train, valid, or test splits, set the corresponding
|
- To specify the train, valid, or test splits, set the corresponding
|
||||||
`{train,valid,test}_split` argument.
|
`{train,valid,test}_split` argument.
|
||||||
|
|
||||||
|
You can specify a list of Hugging Face datasets with a list of records each
|
||||||
|
with the same structure as above. For example:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
hf_dataset:
|
||||||
|
- name: "Open-Orca/OpenOrca"
|
||||||
|
train_split: "train[:90%]"
|
||||||
|
valid_split: "train[-10%:]"
|
||||||
|
prompt_feature: "question"
|
||||||
|
completion_feature: "response"
|
||||||
|
- name: "trl-lib/ultrafeedback_binarized"
|
||||||
|
train_split: "train[:90%]"
|
||||||
|
valid_split: "train[-10%:]"
|
||||||
|
chat_feature: "chosen"
|
||||||
|
```
|
||||||
|
|
||||||
- Arguments specified in `config` will be passed as keyword arguments to
|
- Arguments specified in `config` will be passed as keyword arguments to
|
||||||
[`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
|
[`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
|
||||||
|
|
||||||
|
@ -152,7 +152,7 @@ def main():
|
|||||||
print("Saving...")
|
print("Saving...")
|
||||||
metadata = {}
|
metadata = {}
|
||||||
metadata["model"] = args.model
|
metadata["model"] = args.model
|
||||||
metadata["chat_template"] = tokenizer.chat_template
|
metadata["chat_template"] = json.dumps(tokenizer.chat_template)
|
||||||
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
|
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
|
||||||
save_prompt_cache(args.prompt_cache_file, cache, metadata)
|
save_prompt_cache(args.prompt_cache_file, cache, metadata)
|
||||||
|
|
||||||
|
@ -23,7 +23,6 @@ response = generate(
|
|||||||
tokenizer,
|
tokenizer,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
temp=0.0,
|
|
||||||
prompt_cache=prompt_cache,
|
prompt_cache=prompt_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -93,6 +93,12 @@ def setup_arg_parser():
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Use the default chat template",
|
help="Use the default chat template",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--chat-template-config",
|
||||||
|
help="Additional config for `apply_chat_template`. Should be a dictionary of"
|
||||||
|
" string keys to values represented as a JSON decodable string.",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--verbose",
|
"--verbose",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
@ -149,7 +155,6 @@ def setup_arg_parser():
|
|||||||
def main():
|
def main():
|
||||||
parser = setup_arg_parser()
|
parser = setup_arg_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
mx.random.seed(args.seed)
|
mx.random.seed(args.seed)
|
||||||
|
|
||||||
# Load the prompt cache and metadata if a cache file is provided
|
# Load the prompt cache and metadata if a cache file is provided
|
||||||
@ -195,11 +200,15 @@ def main():
|
|||||||
for eos_token in args.extra_eos_token:
|
for eos_token in args.extra_eos_token:
|
||||||
tokenizer.add_eos_token(eos_token)
|
tokenizer.add_eos_token(eos_token)
|
||||||
|
|
||||||
|
template_kwargs = {}
|
||||||
|
if args.chat_template_config is not None:
|
||||||
|
template_kwargs = json.loads(args.chat_template_config)
|
||||||
|
|
||||||
if args.use_default_chat_template:
|
if args.use_default_chat_template:
|
||||||
if tokenizer.chat_template is None:
|
if tokenizer.chat_template is None:
|
||||||
tokenizer.chat_template = tokenizer.default_chat_template
|
tokenizer.chat_template = tokenizer.default_chat_template
|
||||||
elif using_cache:
|
elif using_cache:
|
||||||
tokenizer.chat_template = metadata["chat_template"]
|
tokenizer.chat_template = json.loads(metadata["chat_template"])
|
||||||
|
|
||||||
prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t")
|
prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t")
|
||||||
prompt = sys.stdin.read() if prompt == "-" else prompt
|
prompt = sys.stdin.read() if prompt == "-" else prompt
|
||||||
@ -209,8 +218,12 @@ def main():
|
|||||||
else:
|
else:
|
||||||
messages = []
|
messages = []
|
||||||
messages.append({"role": "user", "content": prompt})
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
prompt = tokenizer.apply_chat_template(
|
prompt = tokenizer.apply_chat_template(
|
||||||
messages, tokenize=False, add_generation_prompt=True
|
messages,
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
**template_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Treat the prompt as a suffix assuming that the prefix is in the
|
# Treat the prompt as a suffix assuming that the prefix is in the
|
||||||
|
@ -101,6 +101,14 @@ def build_parser():
|
|||||||
choices=["lora", "dora", "full"],
|
choices=["lora", "dora", "full"],
|
||||||
help="Type of fine-tuning to perform: lora, dora, or full.",
|
help="Type of fine-tuning to perform: lora, dora, or full.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--mask-prompt",
|
||||||
|
action="store_true",
|
||||||
|
help="Mask the prompt in the loss when training",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--training-mode",
|
"--training-mode",
|
||||||
type=str,
|
type=str,
|
||||||
|
@ -282,12 +282,12 @@ class MoEGate(nn.Module):
|
|||||||
if self.topk_method == "group_limited_greedy":
|
if self.topk_method == "group_limited_greedy":
|
||||||
bsz, seq_len = x.shape[:2]
|
bsz, seq_len = x.shape[:2]
|
||||||
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
|
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
|
||||||
group_scores = scores.max(axis=-1)
|
group_scores = scores.max(axis=-1, keepdims=True)
|
||||||
k = self.n_group - self.topk_group
|
k = self.n_group - self.topk_group
|
||||||
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
|
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
|
||||||
batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
|
scores = mx.put_along_axis(
|
||||||
seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
|
scores, group_idx, mx.array(0.0, scores.dtype), axis=-2
|
||||||
scores[batch_idx, seq_idx, group_idx] = 0.0
|
)
|
||||||
scores = scores.reshape(bsz, seq_len, -1)
|
scores = scores.reshape(bsz, seq_len, -1)
|
||||||
|
|
||||||
k = self.top_k
|
k = self.top_k
|
||||||
|
@ -271,6 +271,38 @@ class DeepseekV3MLP(nn.Module):
|
|||||||
return down_proj
|
return down_proj
|
||||||
|
|
||||||
|
|
||||||
|
@mx.compile
|
||||||
|
def group_expert_select(
|
||||||
|
gates,
|
||||||
|
e_score_correction_bias,
|
||||||
|
top_k,
|
||||||
|
n_group,
|
||||||
|
topk_group,
|
||||||
|
routed_scaling_factor,
|
||||||
|
norm_topk_prob,
|
||||||
|
):
|
||||||
|
|
||||||
|
k = top_k
|
||||||
|
scores = mx.sigmoid(gates.astype(mx.float32))
|
||||||
|
scores = scores + e_score_correction_bias
|
||||||
|
scores = mx.unflatten(scores, axis=-1, shape=(n_group, -1))
|
||||||
|
group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1, keepdims=True)
|
||||||
|
k = n_group - topk_group
|
||||||
|
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
|
||||||
|
scores = mx.put_along_axis(scores, group_idx, mx.array(0.0), axis=-2)
|
||||||
|
scores = mx.flatten(scores, -2, -1)
|
||||||
|
|
||||||
|
k = top_k
|
||||||
|
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
|
||||||
|
scores = mx.take_along_axis(scores, inds, axis=-1)
|
||||||
|
if top_k > 1 and norm_topk_prob:
|
||||||
|
denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
|
||||||
|
scores = scores / denominator
|
||||||
|
scores = scores * routed_scaling_factor
|
||||||
|
|
||||||
|
return inds, scores
|
||||||
|
|
||||||
|
|
||||||
class MoEGate(nn.Module):
|
class MoEGate(nn.Module):
|
||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -279,38 +311,22 @@ class MoEGate(nn.Module):
|
|||||||
self.norm_topk_prob = config.norm_topk_prob
|
self.norm_topk_prob = config.norm_topk_prob
|
||||||
self.n_routed_experts = config.n_routed_experts
|
self.n_routed_experts = config.n_routed_experts
|
||||||
self.routed_scaling_factor = config.routed_scaling_factor
|
self.routed_scaling_factor = config.routed_scaling_factor
|
||||||
self.topk_method = config.topk_method
|
|
||||||
self.n_group = config.n_group
|
self.n_group = config.n_group
|
||||||
self.topk_group = config.topk_group
|
self.topk_group = config.topk_group
|
||||||
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
|
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
|
||||||
self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))
|
self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))
|
||||||
|
assert config.topk_method == "noaux_tc", "Unsupported topk method."
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
gates = x @ self.weight.T
|
return group_expert_select(
|
||||||
|
x @ self.weight.T,
|
||||||
scores = mx.sigmoid(gates.astype(mx.float32))
|
self.e_score_correction_bias,
|
||||||
|
self.top_k,
|
||||||
assert self.topk_method == "noaux_tc", "Unsupported topk method."
|
self.n_group,
|
||||||
bsz, seq_len = x.shape[:2]
|
self.topk_group,
|
||||||
scores = scores + self.e_score_correction_bias
|
self.routed_scaling_factor,
|
||||||
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
|
self.norm_topk_prob,
|
||||||
group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1)
|
)
|
||||||
k = self.n_group - self.topk_group
|
|
||||||
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
|
|
||||||
batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
|
|
||||||
seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
|
|
||||||
scores[batch_idx, seq_idx, group_idx] = 0.0
|
|
||||||
scores = scores.reshape(bsz, seq_len, -1)
|
|
||||||
|
|
||||||
k = self.top_k
|
|
||||||
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
|
|
||||||
scores = mx.take_along_axis(scores, inds, axis=-1)
|
|
||||||
if self.top_k > 1 and self.norm_topk_prob:
|
|
||||||
denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
|
|
||||||
scores = scores / denominator
|
|
||||||
scores = scores * self.routed_scaling_factor
|
|
||||||
|
|
||||||
return inds, scores
|
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV3MoE(nn.Module):
|
class DeepseekV3MoE(nn.Module):
|
||||||
|
195
llms/mlx_lm/models/granite.py
Normal file
195
llms/mlx_lm/models/granite.py
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
from .rope_utils import initialize_rope
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArgs(BaseModelArgs):
|
||||||
|
model_type: str
|
||||||
|
hidden_size: int
|
||||||
|
num_hidden_layers: int
|
||||||
|
intermediate_size: int
|
||||||
|
num_attention_heads: int
|
||||||
|
rms_norm_eps: float
|
||||||
|
vocab_size: int
|
||||||
|
logits_scaling: float
|
||||||
|
attention_multiplier: float
|
||||||
|
embedding_multiplier: float
|
||||||
|
residual_multiplier: float
|
||||||
|
max_position_embeddings: int
|
||||||
|
num_key_value_heads: int
|
||||||
|
attention_bias: bool
|
||||||
|
mlp_bias: bool
|
||||||
|
rope_theta: float
|
||||||
|
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||||
|
tie_word_embeddings: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
dim = args.hidden_size
|
||||||
|
self.n_heads = n_heads = args.num_attention_heads
|
||||||
|
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||||
|
|
||||||
|
self.head_dim = head_dim = args.hidden_size // n_heads
|
||||||
|
|
||||||
|
self.scale = args.attention_multiplier
|
||||||
|
attention_bias = args.attention_bias
|
||||||
|
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
|
||||||
|
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
||||||
|
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
||||||
|
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
|
||||||
|
|
||||||
|
self.rope = initialize_rope(
|
||||||
|
self.head_dim,
|
||||||
|
args.rope_theta,
|
||||||
|
False,
|
||||||
|
args.rope_scaling,
|
||||||
|
args.max_position_embeddings,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
|
cache: Optional[Any] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
B, L, D = x.shape
|
||||||
|
|
||||||
|
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||||
|
|
||||||
|
# Prepare the queries, keys and values for the attention computation
|
||||||
|
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
|
if cache is not None:
|
||||||
|
queries = self.rope(queries, offset=cache.offset)
|
||||||
|
keys = self.rope(keys, offset=cache.offset)
|
||||||
|
keys, values = cache.update_and_fetch(keys, values)
|
||||||
|
else:
|
||||||
|
queries = self.rope(queries)
|
||||||
|
keys = self.rope(keys)
|
||||||
|
|
||||||
|
output = scaled_dot_product_attention(
|
||||||
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
|
)
|
||||||
|
|
||||||
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
return self.o_proj(output)
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
dim = args.hidden_size
|
||||||
|
hidden_dim = args.intermediate_size
|
||||||
|
if hasattr(args, "mlp_bias"):
|
||||||
|
mlp_bias = args.mlp_bias
|
||||||
|
else:
|
||||||
|
mlp_bias = False
|
||||||
|
|
||||||
|
self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
|
||||||
|
self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
|
||||||
|
self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
|
||||||
|
|
||||||
|
def __call__(self, x) -> mx.array:
|
||||||
|
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerBlock(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.num_attention_heads = args.num_attention_heads
|
||||||
|
self.hidden_size = args.hidden_size
|
||||||
|
self.self_attn = Attention(args)
|
||||||
|
self.mlp = MLP(args)
|
||||||
|
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = nn.RMSNorm(
|
||||||
|
args.hidden_size, eps=args.rms_norm_eps
|
||||||
|
)
|
||||||
|
self.residual_multiplier = args.residual_multiplier
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
|
cache: Optional[Any] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||||
|
h = x + r * self.residual_multiplier
|
||||||
|
r = self.mlp(self.post_attention_layernorm(h))
|
||||||
|
out = h + r * self.residual_multiplier
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class GraniteModel(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
self.vocab_size = args.vocab_size
|
||||||
|
self.num_hidden_layers = args.num_hidden_layers
|
||||||
|
assert self.vocab_size > 0
|
||||||
|
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||||
|
self.layers = [
|
||||||
|
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||||
|
]
|
||||||
|
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||||
|
self.embedding_multiplier = args.embedding_multiplier
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
|
cache=None,
|
||||||
|
):
|
||||||
|
h = self.embed_tokens(inputs) * self.embedding_multiplier
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
|
if cache is None:
|
||||||
|
cache = [None] * len(self.layers)
|
||||||
|
|
||||||
|
for layer, c in zip(self.layers, cache):
|
||||||
|
h = layer(h, mask, cache=c)
|
||||||
|
|
||||||
|
return self.norm(h)
|
||||||
|
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
self.model_type = args.model_type
|
||||||
|
self.model = GraniteModel(args)
|
||||||
|
if not args.tie_word_embeddings:
|
||||||
|
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||||
|
self.logits_scaling = args.logits_scaling
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
|
cache=None,
|
||||||
|
):
|
||||||
|
out = self.model(inputs, mask, cache)
|
||||||
|
if self.args.tie_word_embeddings:
|
||||||
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
|
else:
|
||||||
|
out = self.lm_head(out)
|
||||||
|
return out / self.logits_scaling
|
||||||
|
|
||||||
|
@property
|
||||||
|
def layers(self):
|
||||||
|
return self.model.layers
|
@ -76,7 +76,6 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
head_dim = args.hidden_size // n_heads
|
head_dim = args.hidden_size // n_heads
|
||||||
self.scale = head_dim**-0.5
|
self.scale = head_dim**-0.5
|
||||||
|
|
||||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
|
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
|
||||||
if kv_proj:
|
if kv_proj:
|
||||||
self.k_proj = nn.Linear(
|
self.k_proj = nn.Linear(
|
||||||
@ -107,7 +106,6 @@ class Attention(nn.Module):
|
|||||||
B, L, D = x.shape
|
B, L, D = x.shape
|
||||||
|
|
||||||
queries = self.q_proj(x)
|
queries = self.q_proj(x)
|
||||||
|
|
||||||
if kv_states is None:
|
if kv_states is None:
|
||||||
keys, values = self.k_proj(x), self.v_proj(x)
|
keys, values = self.k_proj(x), self.v_proj(x)
|
||||||
kv_states = keys, values
|
kv_states = keys, values
|
||||||
@ -198,6 +196,9 @@ class DecoderLayer(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = args.hidden_size
|
self.hidden_size = args.hidden_size
|
||||||
self.self_attn = Attention(kv_proj, args)
|
self.self_attn = Attention(kv_proj, args)
|
||||||
|
if args.num_experts == 1:
|
||||||
|
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||||
|
else:
|
||||||
self.mlp = MoeBlock(args)
|
self.mlp = MoeBlock(args)
|
||||||
|
|
||||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||||
@ -231,7 +232,10 @@ class HunYuanModel(nn.Module):
|
|||||||
assert self.vocab_size > 0
|
assert self.vocab_size > 0
|
||||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||||
self.layers = [
|
self.layers = [
|
||||||
DecoderLayer(args=args, kv_proj=(i % args.cla_share_factor) == 0)
|
DecoderLayer(
|
||||||
|
args=args,
|
||||||
|
kv_proj=(not args.use_cla) or (i % args.cla_share_factor) == 0,
|
||||||
|
)
|
||||||
for i in range(args.num_hidden_layers)
|
for i in range(args.num_hidden_layers)
|
||||||
]
|
]
|
||||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||||
@ -251,7 +255,7 @@ class HunYuanModel(nn.Module):
|
|||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
|
|
||||||
for i, (layer, c) in enumerate(zip(self.layers, cache)):
|
for i, (layer, c) in enumerate(zip(self.layers, cache)):
|
||||||
if i % self.args.cla_share_factor == 0:
|
if (not self.args.use_cla) or i % self.args.cla_share_factor == 0:
|
||||||
shared_kv_states = None
|
shared_kv_states = None
|
||||||
h, shared_kv_states = layer(h, mask, c, shared_kv_states)
|
h, shared_kv_states = layer(h, mask, c, shared_kv_states)
|
||||||
|
|
||||||
@ -275,6 +279,29 @@ class Model(nn.Module):
|
|||||||
return self.model.embed_tokens.as_linear(out)
|
return self.model.embed_tokens.as_linear(out)
|
||||||
|
|
||||||
def sanitize(self, weights):
|
def sanitize(self, weights):
|
||||||
|
|
||||||
|
if "model.layers.0.mlp.gate_and_up_proj.weight" in weights:
|
||||||
|
new_weights = {}
|
||||||
|
D = self.args.hidden_size
|
||||||
|
n_kv_heads = self.args.num_key_value_heads
|
||||||
|
n_kv_groups = self.args.num_attention_heads // n_kv_heads
|
||||||
|
head_dim = D // self.args.num_attention_heads
|
||||||
|
for k, v in weights.items():
|
||||||
|
if "qkv_proj" in k:
|
||||||
|
v = v.reshape(n_kv_heads, n_kv_groups + 2, head_dim, -1)
|
||||||
|
splits = v.split([n_kv_groups, n_kv_groups + 1], axis=1)
|
||||||
|
for k_up, v_new in zip(["q_proj", "k_proj", "v_proj"], splits):
|
||||||
|
k_new = k.replace("qkv_proj", k_up)
|
||||||
|
new_weights[k_new] = mx.flatten(v_new, 0, 2)
|
||||||
|
elif "gate_and_up_proj" in k:
|
||||||
|
splits = v.split(2, axis=0)
|
||||||
|
for k_up, v_new in zip(["up_proj", "gate_proj"], splits):
|
||||||
|
k_new = k.replace("gate_and_up_proj", k_up)
|
||||||
|
new_weights[k_new] = v_new
|
||||||
|
else:
|
||||||
|
new_weights[k] = v
|
||||||
|
weights = new_weights
|
||||||
|
|
||||||
if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
|
if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
|
||||||
return weights
|
return weights
|
||||||
for l in range(self.args.num_hidden_layers):
|
for l in range(self.args.num_hidden_layers):
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
@ -368,3 +369,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None):
|
|||||||
detokenizer_class,
|
detokenizer_class,
|
||||||
eos_token_ids=eos_token_ids,
|
eos_token_ids=eos_token_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def no_bos_or_eos(sequence: List, bos: int, eos: int) -> List:
|
||||||
|
removed_bos = sequence if sequence[0] != bos else sequence[1:]
|
||||||
|
return removed_bos[:-1] if removed_bos[-1] == eos else removed_bos
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
|
import itertools
|
||||||
import json
|
import json
|
||||||
|
import types
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
@ -92,14 +94,24 @@ class ChatDataset:
|
|||||||
https://platform.openai.com/docs/guides/fine-tuning/example-format
|
https://platform.openai.com/docs/guides/fine-tuning/example-format
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
|
def __init__(
|
||||||
self._data = [
|
self,
|
||||||
tokenizer.apply_chat_template(
|
data: List[Dict[str, str]],
|
||||||
d["messages"],
|
tokenizer: PreTrainedTokenizer,
|
||||||
tools=d.get("tools", None),
|
chat_key: str = "messages",
|
||||||
)
|
mask_prompt: bool = False,
|
||||||
for d in data
|
):
|
||||||
]
|
self._data = []
|
||||||
|
for d in data:
|
||||||
|
messages = d[chat_key]
|
||||||
|
tools = d.get("tools", None)
|
||||||
|
tokens = tokenizer.apply_chat_template(messages, tools=tools)
|
||||||
|
if mask_prompt:
|
||||||
|
messages = messages[:-1]
|
||||||
|
offset = len(tokenizer.apply_chat_template(messages, tools=tools))
|
||||||
|
self._data.append((tokens, offset))
|
||||||
|
else:
|
||||||
|
self._data.append(tokens)
|
||||||
|
|
||||||
def __getitem__(self, idx: int):
|
def __getitem__(self, idx: int):
|
||||||
return self._data[idx]
|
return self._data[idx]
|
||||||
@ -121,16 +133,25 @@ class CompletionsDataset:
|
|||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
prompt_key: str,
|
prompt_key: str,
|
||||||
completion_key: str,
|
completion_key: str,
|
||||||
|
mask_prompt: bool,
|
||||||
):
|
):
|
||||||
self._data = [
|
self._data = []
|
||||||
tokenizer.apply_chat_template(
|
for d in data:
|
||||||
|
tokens = tokenizer.apply_chat_template(
|
||||||
[
|
[
|
||||||
{"role": "user", "content": d[prompt_key]},
|
{"role": "user", "content": d[prompt_key]},
|
||||||
{"role": "assistant", "content": d[completion_key]},
|
{"role": "assistant", "content": d[completion_key]},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
for d in data
|
if mask_prompt:
|
||||||
]
|
offset = len(
|
||||||
|
tokenizer.apply_chat_template(
|
||||||
|
[{"role": "user", "content": d[prompt_key]}]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._data.append((tokens, offset))
|
||||||
|
else:
|
||||||
|
self._data.append(tokens)
|
||||||
|
|
||||||
def __getitem__(self, idx: int):
|
def __getitem__(self, idx: int):
|
||||||
return self._data[idx]
|
return self._data[idx]
|
||||||
@ -139,52 +160,60 @@ class CompletionsDataset:
|
|||||||
return len(self._data)
|
return len(self._data)
|
||||||
|
|
||||||
|
|
||||||
|
class ConcatenatedDataset:
|
||||||
|
def __init__(self, data: List[Any]):
|
||||||
|
self._data = list(itertools.chain(*data))
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int):
|
||||||
|
return self._data[idx]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._data)
|
||||||
|
|
||||||
def create_dataset(
|
def create_dataset(
|
||||||
args,
|
args,
|
||||||
data,
|
data,
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
prompt_feature: Optional[str] = None,
|
config,
|
||||||
completion_feature: Optional[str] = None,
|
|
||||||
):
|
):
|
||||||
prompt_feature = prompt_feature or "prompt"
|
mask_prompt = getattr(config, "mask_prompt", False)
|
||||||
completion_feature = completion_feature or "completion"
|
prompt_feature = getattr(config, "prompt_feature", "prompt")
|
||||||
|
text_feature = getattr(config, "text_feature", "text")
|
||||||
|
completion_feature = getattr(config, "completion_feature", "completion")
|
||||||
|
chat_feature = getattr(config, "chat_feature", "messages")
|
||||||
sample = data[0]
|
sample = data[0]
|
||||||
|
|
||||||
if args.training_mode == "normal":
|
if args.training_mode == "normal":
|
||||||
if "messages" in sample:
|
if chat_feature in sample:
|
||||||
return ChatDataset(data, tokenizer)
|
return ChatDataset(data, tokenizer, chat_key=chat_feature, mask_prompt=mask_prompt)
|
||||||
elif prompt_feature in sample and completion_feature in sample:
|
elif prompt_feature in sample and completion_feature in sample:
|
||||||
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
|
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature, mask_prompt)
|
||||||
elif "text" in sample:
|
elif text_feature in sample:
|
||||||
return Dataset(data, tokenizer)
|
if mask_prompt:
|
||||||
|
raise ValueError("Prompt masking not supported for text dataset.")
|
||||||
|
return Dataset(data, tokenizer, text_key=text_feature)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unsupported data format, check the supported formats here:\n"
|
"Unsupported data format, check the supported formats here:\n"
|
||||||
"https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
|
"https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
|
||||||
)
|
)
|
||||||
elif args.training_mode == "orpo":
|
else:
|
||||||
if "chosen" in sample and "rejected" in sample:
|
if "chosen" in sample and "rejected" in sample:
|
||||||
return ORPODataset(data, tokenizer)
|
return ORPODataset(data, tokenizer)
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Unsupported training mode, check the supported training modes and their formats here:\n"
|
|
||||||
"https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#training-modes."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_local_dataset(
|
def load_local_dataset(
|
||||||
args,
|
args,
|
||||||
data_path: Path,
|
data_path: Path,
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
prompt_feature: Optional[str] = None,
|
config,
|
||||||
completion_feature: Optional[str] = None,
|
|
||||||
):
|
):
|
||||||
def load_subset(path):
|
def load_subset(path):
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
return []
|
return []
|
||||||
with open(path, "r") as fid:
|
with open(path, "r") as fid:
|
||||||
data = [json.loads(l) for l in fid]
|
data = [json.loads(l) for l in fid]
|
||||||
return create_dataset(args, data, tokenizer, prompt_feature, completion_feature)
|
|
||||||
|
return create_dataset(data, tokenizer, config)
|
||||||
|
|
||||||
names = ("train", "valid", "test")
|
names = ("train", "valid", "test")
|
||||||
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
|
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
|
||||||
@ -195,8 +224,7 @@ def load_hf_dataset(
|
|||||||
args,
|
args,
|
||||||
data_id: str,
|
data_id: str,
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
prompt_feature: Optional[str] = None,
|
config,
|
||||||
completion_feature: Optional[str] = None,
|
|
||||||
):
|
):
|
||||||
from datasets import exceptions, load_dataset
|
from datasets import exceptions, load_dataset
|
||||||
|
|
||||||
@ -207,9 +235,7 @@ def load_hf_dataset(
|
|||||||
|
|
||||||
train, valid, test = [
|
train, valid, test = [
|
||||||
(
|
(
|
||||||
create_dataset(
|
create_dataset(args, dataset[n], tokenizer, config)
|
||||||
args, dataset[n], tokenizer, prompt_feature, completion_feature
|
|
||||||
)
|
|
||||||
if n in dataset.keys()
|
if n in dataset.keys()
|
||||||
else []
|
else []
|
||||||
)
|
)
|
||||||
@ -225,42 +251,61 @@ def load_hf_dataset(
|
|||||||
def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
|
def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
|
||||||
import datasets
|
import datasets
|
||||||
|
|
||||||
hf_args = args.hf_dataset
|
def create_hf_dataset(dataset_name, config, split, hf_config):
|
||||||
dataset_name = hf_args["name"]
|
|
||||||
print(f"Loading Hugging Face dataset {dataset_name}.")
|
|
||||||
text_feature = hf_args.get("text_feature")
|
|
||||||
prompt_feature = hf_args.get("prompt_feature")
|
|
||||||
completion_feature = hf_args.get("completion_feature")
|
|
||||||
|
|
||||||
def create_hf_dataset(split: str = None):
|
|
||||||
ds = datasets.load_dataset(
|
ds = datasets.load_dataset(
|
||||||
dataset_name,
|
dataset_name,
|
||||||
split=split,
|
split=split,
|
||||||
**hf_args.get("config", {}),
|
**hf_config,
|
||||||
)
|
|
||||||
if prompt_feature and completion_feature:
|
|
||||||
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
|
|
||||||
elif text_feature:
|
|
||||||
return Dataset(ds, tokenizer, text_key=text_feature)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Specify either a prompt and completion feature or a text "
|
|
||||||
"feature for the Hugging Face dataset."
|
|
||||||
)
|
)
|
||||||
|
return create_dataset(ds, tokenizer, config)
|
||||||
|
|
||||||
|
dataset_collection = args.hf_dataset
|
||||||
|
if isinstance(dataset_collection, dict):
|
||||||
|
dataset_collection = [dataset_collection]
|
||||||
|
|
||||||
|
collection = []
|
||||||
|
for ds in dataset_collection:
|
||||||
|
ds_name = ds["name"]
|
||||||
|
print(f"Loading Hugging Face dataset {ds_name}.")
|
||||||
|
ds["mask_prompt"] = getattr(args, "mask_prompt", False)
|
||||||
|
config = types.SimpleNamespace(**ds)
|
||||||
|
hf_config = ds.get("config", {})
|
||||||
if args.train:
|
if args.train:
|
||||||
train_split = hf_args.get("train_split", "train[:80%]")
|
train_split = ds.get("train_split", "train[:80%]")
|
||||||
valid_split = hf_args.get("valid_split", "train[-10%:]")
|
valid_split = ds.get("valid_split", "train[-10%:]")
|
||||||
train = create_hf_dataset(split=train_split)
|
train = create_hf_dataset(
|
||||||
valid = create_hf_dataset(split=valid_split)
|
ds_name,
|
||||||
|
config,
|
||||||
|
train_split,
|
||||||
|
hf_config,
|
||||||
|
)
|
||||||
|
valid = create_hf_dataset(
|
||||||
|
ds_name,
|
||||||
|
config,
|
||||||
|
valid_split,
|
||||||
|
hf_config,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
train, valid = [], []
|
train, valid = [], []
|
||||||
|
|
||||||
if args.test:
|
if args.test:
|
||||||
test = create_hf_dataset(split=hf_args.get("test_split"))
|
test_split = ds.get("test_split")
|
||||||
|
test = create_hf_dataset(
|
||||||
|
ds_name,
|
||||||
|
config,
|
||||||
|
test_split,
|
||||||
|
hf_config,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
test = []
|
test = []
|
||||||
|
|
||||||
return train, valid, test
|
collection.append((train, valid, test))
|
||||||
|
|
||||||
|
if len(collection) == 1:
|
||||||
|
return collection[0]
|
||||||
|
|
||||||
|
# Otherwise concatenate them
|
||||||
|
return tuple(map(ConcatenatedDataset, zip(*collection)))
|
||||||
|
|
||||||
|
|
||||||
def load_dataset(args, tokenizer: PreTrainedTokenizer):
|
def load_dataset(args, tokenizer: PreTrainedTokenizer):
|
||||||
@ -268,18 +313,11 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
|
|||||||
train, valid, test = load_custom_hf_dataset(args, tokenizer)
|
train, valid, test = load_custom_hf_dataset(args, tokenizer)
|
||||||
else:
|
else:
|
||||||
data_path = Path(args.data)
|
data_path = Path(args.data)
|
||||||
|
|
||||||
prompt_feature = getattr(args, "prompt_feature", None)
|
|
||||||
completion_feature = getattr(args, "completion_feature", None)
|
|
||||||
if data_path.exists():
|
if data_path.exists():
|
||||||
train, valid, test = load_local_dataset(
|
train, valid, test = load_local_dataset(args, data_path, tokenizer, args)
|
||||||
args, data_path, tokenizer, prompt_feature, completion_feature
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
print(f"Loading Hugging Face dataset {args.data}.")
|
print(f"Loading Hugging Face dataset {args.data}.")
|
||||||
train, valid, test = load_hf_dataset(
|
train, valid, test = load_hf_dataset(args.data, tokenizer, args)
|
||||||
args, args.data, tokenizer, prompt_feature, completion_feature
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.train and len(train) == 0:
|
if args.train and len(train) == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -5,13 +5,16 @@ import shutil
|
|||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mlx.nn.utils import average_gradients
|
from mlx.nn.utils import average_gradients
|
||||||
from mlx.utils import tree_flatten
|
from mlx.utils import tree_flatten
|
||||||
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
|
from .datasets import CompletionsDataset
|
||||||
|
|
||||||
|
|
||||||
def grad_checkpoint(layer):
|
def grad_checkpoint(layer):
|
||||||
@ -63,20 +66,30 @@ class TrainingArgs:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def default_loss(model, inputs, targets, lengths):
|
def default_loss(model, batch, lengths):
|
||||||
|
inputs = batch[:, :-1]
|
||||||
|
targets = batch[:, 1:]
|
||||||
|
|
||||||
logits = model(inputs)
|
logits = model(inputs)
|
||||||
logits = logits.astype(mx.float32)
|
logits = logits.astype(mx.float32)
|
||||||
|
|
||||||
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
|
steps = mx.arange(1, targets.shape[1] + 1)
|
||||||
|
mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
|
||||||
|
|
||||||
ce = nn.losses.cross_entropy(logits, targets) * length_mask
|
ce = nn.losses.cross_entropy(logits, targets) * mask
|
||||||
ntoks = length_mask.sum()
|
ntoks = mask.sum()
|
||||||
ce = ce.sum() / ntoks
|
ce = ce.sum() / ntoks
|
||||||
|
|
||||||
return ce, ntoks
|
return ce, ntoks
|
||||||
|
|
||||||
|
|
||||||
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
|
def iterate_batches(
|
||||||
|
dataset,
|
||||||
|
tokenizer,
|
||||||
|
batch_size,
|
||||||
|
max_seq_length,
|
||||||
|
train=False,
|
||||||
|
):
|
||||||
# Sort by length:
|
# Sort by length:
|
||||||
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
|
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
|
||||||
if len(dataset) < batch_size:
|
if len(dataset) < batch_size:
|
||||||
@ -101,6 +114,10 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
|
|||||||
indices = np.random.permutation(len(batch_idx))
|
indices = np.random.permutation(len(batch_idx))
|
||||||
for i in indices:
|
for i in indices:
|
||||||
batch = [dataset[j] for j in batch_idx[i]]
|
batch = [dataset[j] for j in batch_idx[i]]
|
||||||
|
if len(batch[0]) == 2:
|
||||||
|
batch, offsets = zip(*batch)
|
||||||
|
else:
|
||||||
|
offsets = [0] * len(batch)
|
||||||
lengths = [len(x) for x in batch]
|
lengths = [len(x) for x in batch]
|
||||||
if max(lengths) > max_seq_length:
|
if max(lengths) > max_seq_length:
|
||||||
print(
|
print(
|
||||||
@ -123,8 +140,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
|
|||||||
truncated_length # Update lengths to match truncated lengths
|
truncated_length # Update lengths to match truncated lengths
|
||||||
)
|
)
|
||||||
batch = mx.array(batch_arr)
|
batch = mx.array(batch_arr)
|
||||||
|
yield batch, mx.array(list(zip(offsets, lengths)))
|
||||||
yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
|
|
||||||
|
|
||||||
if not train:
|
if not train:
|
||||||
break
|
break
|
||||||
|
@ -94,6 +94,7 @@ def linear_to_lora_layers(
|
|||||||
"phimoe",
|
"phimoe",
|
||||||
"gemma",
|
"gemma",
|
||||||
"gemma2",
|
"gemma2",
|
||||||
|
"granite",
|
||||||
"helium",
|
"helium",
|
||||||
"starcoder2",
|
"starcoder2",
|
||||||
"cohere",
|
"cohere",
|
||||||
|
@ -78,14 +78,15 @@ class TestDatasets(unittest.TestCase):
|
|||||||
self.assertTrue(isinstance(train, datasets.ChatDataset))
|
self.assertTrue(isinstance(train, datasets.ChatDataset))
|
||||||
|
|
||||||
def test_hf(self):
|
def test_hf(self):
|
||||||
args = types.SimpleNamespace(
|
hf_args = {
|
||||||
hf_dataset={
|
|
||||||
"name": "billsum",
|
"name": "billsum",
|
||||||
"prompt_feature": "text",
|
"prompt_feature": "text",
|
||||||
"completion_feature": "summary",
|
"completion_feature": "summary",
|
||||||
"train_split": "train[:2%]",
|
"train_split": "train[:2%]",
|
||||||
"valid_split": "train[-2%:]",
|
"valid_split": "train[-2%:]",
|
||||||
},
|
}
|
||||||
|
args = types.SimpleNamespace(
|
||||||
|
hf_dataset=hf_args,
|
||||||
test=False,
|
test=False,
|
||||||
train=True,
|
train=True,
|
||||||
)
|
)
|
||||||
@ -97,6 +98,16 @@ class TestDatasets(unittest.TestCase):
|
|||||||
self.assertTrue(len(valid[0]) > 0)
|
self.assertTrue(len(valid[0]) > 0)
|
||||||
self.assertEqual(len(test), 0)
|
self.assertEqual(len(test), 0)
|
||||||
|
|
||||||
|
args = types.SimpleNamespace(
|
||||||
|
hf_dataset=[hf_args, hf_args],
|
||||||
|
test=False,
|
||||||
|
train=True,
|
||||||
|
)
|
||||||
|
train_double, valid_double, test_double = datasets.load_dataset(args, tokenizer)
|
||||||
|
self.assertEqual(2 * len(train), len(train_double))
|
||||||
|
self.assertEqual(2 * len(valid), len(valid_double))
|
||||||
|
self.assertEqual(2 * len(test), len(test_double))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -134,9 +134,7 @@ def find_alignment(
|
|||||||
logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :])
|
logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :])
|
||||||
# consider only the logits associated with predicting text
|
# consider only the logits associated with predicting text
|
||||||
sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot]
|
sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot]
|
||||||
token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype(
|
token_probs = mx.softmax(sampled_logits, precise=True, axis=-1)
|
||||||
sampled_logits.dtype
|
|
||||||
)
|
|
||||||
text_token_probs = mx.take_along_axis(
|
text_token_probs = mx.take_along_axis(
|
||||||
token_probs, mx.array(text_tokens)[:, None], axis=1
|
token_probs, mx.array(text_tokens)[:, None], axis=1
|
||||||
).squeeze(1)
|
).squeeze(1)
|
||||||
@ -144,10 +142,11 @@ def find_alignment(
|
|||||||
|
|
||||||
# heads * tokens * frames
|
# heads * tokens * frames
|
||||||
weights = mx.stack(
|
weights = mx.stack(
|
||||||
[cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads]
|
[cross_qk[_l][0, _h] for _l, _h in model.alignment_heads.tolist()]
|
||||||
)
|
)
|
||||||
weights = weights[:, :, : num_frames // 2]
|
weights = weights[:, :, : num_frames // 2]
|
||||||
weights = mx.softmax(weights * qk_scale, axis=-1)
|
weights = mx.softmax(weights * qk_scale, axis=-1, precise=True)
|
||||||
|
weights = weights.astype(mx.float32)
|
||||||
mean = mx.mean(weights, axis=-2, keepdims=True)
|
mean = mx.mean(weights, axis=-2, keepdims=True)
|
||||||
std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt()
|
std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt()
|
||||||
weights = (weights - mean) / std
|
weights = (weights - mean) / std
|
||||||
|
@ -84,7 +84,7 @@ class MultiHeadAttention(nn.Module):
|
|||||||
w = mx.softmax(qk, axis=-1, precise=True)
|
w = mx.softmax(qk, axis=-1, precise=True)
|
||||||
out = (w @ v).transpose(0, 2, 1, 3)
|
out = (w @ v).transpose(0, 2, 1, 3)
|
||||||
out = out.reshape(n_batch, n_ctx, n_state)
|
out = out.reshape(n_batch, n_ctx, n_state)
|
||||||
return out, qk.astype(mx.float32)
|
return out, qk
|
||||||
|
|
||||||
|
|
||||||
class ResidualAttentionBlock(nn.Module):
|
class ResidualAttentionBlock(nn.Module):
|
||||||
|
Loading…
Reference in New Issue
Block a user