mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-08-28 00:30:09 +08:00

put offset in prompt, simplify

This commit is contained in:
parent 6ace6dc6b2
commit 6e9542a934
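In short, this commit reworks prompt masking for LoRA fine-tuning: instead of searching each batch for a configurable response_template, a dataset built with prompt masking stores the length of the tokenized prompt next to the token ids, the batch iterator forwards those offsets, and the loss skips every target position before the offset. A toy sketch of the new item shape (variable names are illustrative, not taken from the diff below):

    # With prompt masking enabled, a dataset item is (tokens, prompt_offset)
    # instead of just tokens; positions before prompt_offset are excluded
    # from the cross-entropy loss.
    tokens = [1, 5, 9, 2, 7, 7, 2]   # prompt + response token ids (toy values)
    prompt_offset = 4                # number of tokens belonging to the prompt
    item = (tokens, prompt_offset)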
@@ -12,7 +12,7 @@ import mlx.optimizers as optim
 import numpy as np
 import yaml
 
-from .tokenizer_utils import TokenizerWrapper, no_bos_or_eos
+from .tokenizer_utils import TokenizerWrapper
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import (
@@ -62,7 +62,6 @@ CONFIG_DEFAULTS = {
     "grad_checkpoint": False,
     "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
-    "response_template": None,
 }
 
 
@@ -97,10 +96,9 @@ def build_parser():
     )
 
     parser.add_argument(
-        "--mask-inputs",
-        dest="mask_inputs",
+        "--mask-prompt",
         action="store_true",
-        help="Whether to mask the inputs when training. Default is False.",
+        help="Mask the prompt in the loss when training",
         default=False,
     )
 
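The flag rename means prompt masking is now switched on with --mask-prompt; assuming the usual mlx_lm.lora entry point and its standard arguments, an invocation would look roughly like:

    python -m mlx_lm.lora --model <model-path> --data <data-dir> --train --mask-prompt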
@@ -182,13 +180,6 @@ def train_model(
     valid_set,
     training_callback: TrainingCallback = None,
 ):
-    from .tuner.trainer import (
-        default_loss,
-        input_masked_loss,
-        iterate_batches,
-        iterate_completion_batches,
-    )
-
     model.freeze()
     if args.fine_tune_type == "full":
         for l in model.layers[-min(args.num_layers, 0) :]:
@@ -217,17 +208,6 @@ def train_model(
     adapter_file = adapter_path / "adapters.safetensors"
     save_config(vars(args), adapter_path / "adapter_config.json")
 
-    if isinstance(args.response_template, str):
-        response_generation_tokens = tokenizer.encode(
-            args.response_template, add_special_tokens=False
-        )
-    else:
-        if not all([item.isinstance(int) for item in args.response_template]):
-            raise ValueError(
-                "Response template must be a list of integers if it is not a string."
-            )
-        response_generation_tokens = args.response_template
-
     # init training args
     training_args = TrainingArgs(
         batch_size=args.batch_size,
@@ -239,9 +219,6 @@ def train_model(
         adapter_file=adapter_file,
         max_seq_length=args.max_seq_length,
         grad_checkpoint=args.grad_checkpoint,
-        response_generation_tokens=no_bos_or_eos(
-            response_generation_tokens, tokenizer.bos_token_id, tokenizer.eos_token_id
-        ),
     )
 
     model.train()
@@ -251,9 +228,6 @@ def train_model(
         )
     )
 
-    if args.mask_inputs:
-        print("Masking inputs..")
-
     # Train model
     train(
         model=model,
@@ -263,10 +237,6 @@ def train_model(
         train_dataset=train_set,
         val_dataset=valid_set,
         training_callback=training_callback,
-        iterate_batches=(
-            iterate_completion_batches if args.mask_inputs else iterate_batches
-        ),
-        loss=input_masked_loss if args.mask_inputs else default_loss,
     )
 
 
@@ -40,14 +40,19 @@ class ChatDataset:
         data: List[Dict[str, str]],
         tokenizer: PreTrainedTokenizer,
         chat_key: str = "messages",
+        mask_prompt: bool = False,
     ):
-        self._data = [
-            tokenizer.apply_chat_template(
-                d[chat_key],
-                tools=d.get("tools", None),
-            )
-            for d in data
-        ]
+        self._data = []
+        for d in data:
+            messages = d[chat_key]
+            tools = d.get("tools", None)
+            tokens = tokenizer.apply_chat_template(messages, tools=tools)
+            if mask_prompt:
+                messages = messages[:-1]
+                offset = len(tokenizer.apply_chat_template(messages, tools=tools))
+                self._data.append((tokens, offset))
+            else:
+                self._data.append(tokens)
 
     def __getitem__(self, idx: int):
         return self._data[idx]
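The offset stored above is simply the length of the chat template applied to everything except the final assistant message. A hedged illustration with a Hugging Face tokenizer (the model id is a placeholder, and the exact counts depend on that model's template):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("some-chat-model")  # placeholder model id
    messages = [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ]
    full = tok.apply_chat_template(messages)               # prompt + response token ids
    offset = len(tok.apply_chat_template(messages[:-1]))   # tokens that belong to the prompt
    # Only full[offset:], i.e. the assistant turn, will contribute to the loss.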
@@ -69,16 +74,25 @@ class CompletionsDataset:
         tokenizer: PreTrainedTokenizer,
         prompt_key: str,
         completion_key: str,
+        mask_prompt: bool,
     ):
-        self._data = [
-            tokenizer.apply_chat_template(
+        self._data = []
+        for d in data:
+            tokens = tokenizer.apply_chat_template(
                 [
                     {"role": "user", "content": d[prompt_key]},
                     {"role": "assistant", "content": d[completion_key]},
                 ],
             )
-            for d in data
-        ]
+            if mask_prompt:
+                offset = len(
+                    tokenizer.apply_chat_template(
+                        [{"role": "user", "content": d[prompt_key]}]
+                    )
+                )
+                self._data.append((tokens, offset))
+            else:
+                self._data.append(tokens)
 
     def __getitem__(self, idx: int):
         return self._data[idx]
@@ -101,17 +115,21 @@ class ConcatenatedDataset:
 def create_dataset(
     data,
     tokenizer: PreTrainedTokenizer,
-    prompt_feature: Optional[str] = None,
-    completion_feature: Optional[str] = None,
+    config: Dict,
 ):
-    prompt_feature = prompt_feature or "prompt"
-    completion_feature = completion_feature or "completion"
+    mask_prompt = getattr(config, "mask_prompt", False)
+    prompt_feature = getattr(config, "prompt_feature", "prompt")
+    completion_feature = getattr(config, "completion_feature", "completion")
     sample = data[0]
     if "messages" in sample:
-        return ChatDataset(data, tokenizer)
+        return ChatDataset(data, tokenizer, mask_prompt=mask_prompt)
     elif prompt_feature in sample and completion_feature in sample:
-        return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
+        return CompletionsDataset(
+            data, tokenizer, prompt_feature, completion_feature, mask_prompt
+        )
     elif "text" in sample:
+        if mask_prompt:
+            raise ValueError("Prompt masking not supported for text dataset.")
         return Dataset(data, tokenizer)
     else:
         raise ValueError(
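Worth noting: although the new parameter is annotated as Dict, the options are read with getattr, so create_dataset effectively expects a namespace-like object (such as the parsed CLI arguments) rather than a plain dict, whose keys getattr would not find. A minimal sketch of a compatible config object (hypothetical values):

    from types import SimpleNamespace

    config = SimpleNamespace(
        mask_prompt=True,
        prompt_feature="prompt",
        completion_feature="completion",
    )
    # dataset = create_dataset(data, tokenizer, config)  # data and tokenizer assumed in scope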
@@ -123,15 +141,14 @@ def create_dataset(
 def load_local_dataset(
     data_path: Path,
     tokenizer: PreTrainedTokenizer,
-    prompt_feature: Optional[str] = None,
-    completion_feature: Optional[str] = None,
+    config: Dict,
 ):
     def load_subset(path):
         if not path.exists():
             return []
         with open(path, "r") as fid:
             data = [json.loads(l) for l in fid]
-        return create_dataset(data, tokenizer, prompt_feature, completion_feature)
+        return create_dataset(data, tokenizer, config)
 
     names = ("train", "valid", "test")
     train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
@@ -141,8 +158,7 @@ def load_local_dataset(
 def load_hf_dataset(
     data_id: str,
     tokenizer: PreTrainedTokenizer,
-    prompt_feature: Optional[str] = None,
-    completion_feature: Optional[str] = None,
+    config: Dict,
 ):
     from datasets import exceptions, load_dataset
 
@@ -153,9 +169,7 @@ def load_hf_dataset(
 
         train, valid, test = [
             (
-                create_dataset(
-                    dataset[n], tokenizer, prompt_feature, completion_feature
-                )
+                create_dataset(dataset[n], tokenizer, config)
                 if n in dataset.keys()
                 else []
             )
@@ -186,10 +200,16 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
             **config,
         )
         if prompt_feature and completion_feature:
-            return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
+            return CompletionsDataset(
+                data, tokenizer, prompt_feature, completion_feature, mask_prompt
+            )
         elif chat_feature:
-            return ChatDataset(ds, tokenizer, chat_key=chat_feature)
+            return ChatDataset(
+                ds, tokenizer, chat_key=chat_feature, mask_prompt=mask_prompt
+            )
         elif text_feature:
+            if mask_prompt:
+                raise ValueError("Prompt masking not supported for text dataset.")
             return Dataset(ds, tokenizer, text_key=text_feature)
         else:
             raise ValueError(
@@ -262,18 +282,11 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
         train, valid, test = load_custom_hf_dataset(args, tokenizer)
     else:
         data_path = Path(args.data)
-
-        prompt_feature = getattr(args, "prompt_feature", None)
-        completion_feature = getattr(args, "completion_feature", None)
         if data_path.exists():
-            train, valid, test = load_local_dataset(
-                data_path, tokenizer, prompt_feature, completion_feature
-            )
+            train, valid, test = load_local_dataset(data_path, tokenizer, args)
         else:
            print(f"Loading Hugging Face dataset {args.data}.")
-            train, valid, test = load_hf_dataset(
-                args.data, tokenizer, prompt_feature, completion_feature
-            )
+            train, valid, test = load_hf_dataset(args.data, tokenizer, args)
 
     if args.train and len(train) == 0:
         raise ValueError(
@@ -64,39 +64,20 @@ class TrainingArgs:
         default=False,
         metadata={"help": "Use gradient checkpointing to reduce memory use."},
     )
-    response_generation_tokens: Optional[List[int]] = field(
-        default_factory=list,
-        metadata={"help": "List of token ids that mark the beginning of the response"},
-    )
 
 
-def input_masked_loss(model, inputs, response_prefix_lengths, lengths):
-    shifted_inputs = inputs[:, :-1]
-    shifted_labels = inputs[:, 1:]
-    logits = model(shifted_inputs)
-    logits = logits.astype(mx.float32)
+def default_loss(model, batch, lengths):
+    inputs = batch[:, :-1]
+    targets = batch[:, 1:]
 
-    mask_width = shifted_inputs.shape[1]
-    token_indices = mx.arange(mask_width)[None, :]
-    mask = mx.logical_and(
-        token_indices >= response_prefix_lengths[:, None],
-        token_indices < lengths[:, None],
-    )
-
-    ce = nn.losses.cross_entropy(logits, shifted_labels) * mask
-    ntoks = mask.sum()
-    ce = ce.sum() / ntoks
-    return ce, ntoks
-
-
-def default_loss(model, inputs, targets, lengths):
     logits = model(inputs)
     logits = logits.astype(mx.float32)
 
-    length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
+    steps = mx.arange(1, targets.shape[1] + 1)
+    mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
 
-    ce = nn.losses.cross_entropy(logits, targets) * length_mask
-    ntoks = length_mask.sum()
+    ce = nn.losses.cross_entropy(logits, targets) * mask
+    ntoks = mask.sum()
     ce = ce.sum() / ntoks
 
     return ce, ntoks
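To make the new mask concrete: lengths now holds an (offset, length) pair per row, and a 1-based target position t contributes to the loss only when offset <= t <= length. A small worked example, independent of the training code (toy numbers; assumes mlx is installed):

    import mlx.core as mx

    lengths = mx.array([[3, 6], [0, 4]])  # (prompt offset, sequence length) per example
    steps = mx.arange(1, 6 + 1)           # target positions 1..6
    mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
    print(mask)
    # row 0 -> [False, False, True, True, True, True]  (positions before the 3-token prompt offset are excluded)
    # row 1 -> [True, True, True, True, False, False]  (no prompt masking, padding after length 4 is excluded)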
@@ -112,96 +93,12 @@ def contains(small_list: List, big_list: List) -> Tuple[int, int]:
     return ind, ind + small_list_length - 1
 
 
-def iterate_completion_batches(
-    dataset: CompletionsDataset,
-    tokenizer: PreTrainedTokenizer,
-    batch_size: int,
-    max_seq_length: int,
-    train: bool = False,
-    response_generation_tokens: Optional[List[int]] = None,
-):
-    """
-    A version of iterate_batches that works with completion datasets, tracks the boundaries between input/output tokens
-    and returns the lengths of input tokens as well as that of the full sequences.
-    """
-    idx = sorted(range(len(dataset)), key=lambda i: len(dataset[i]))
-    if len(dataset) < batch_size:
-        raise ValueError(
-            f"Dataset must have at least batch_size={batch_size}"
-            f" examples but only has {len(dataset)}."
-        )
-
-    # If running in distributed mode (N machines) then each one should skip N-1
-    # samples
-    step = mx.distributed.init().size()
-    if batch_size % step != 0:
-        raise ValueError("The batch size must be divisible by the number of workers")
-    # Make the batches:
-    batch_idx = [
-        idx[i : i + batch_size : step]
-        for i in range(0, len(idx) - batch_size + 1, batch_size)
-    ]
-    while True:
-        indices = np.random.permutation(len(batch_idx))
-        for i in indices:
-            response_prefix_lengths = []
-            batch = []
-            for j in batch_idx[i]:
-                full_sequence = dataset.get_item(j, tokenize=True)
-                if full_sequence[-1] != tokenizer.eos_token_id:
-                    full_sequence.append(tokenizer.eos_token_id)
-                batch.append(full_sequence)
-                if len(response_generation_tokens) > 1:
-                    response_marker_begin, response_marker_end = contains(
-                        response_generation_tokens, full_sequence
-                    )
-                    response_prefix_lengths.append(response_marker_end + 1)
-                else:
-                    response_marker_begin = full_sequence.index(
-                        response_generation_tokens[0]
-                    )
-                    response_prefix_lengths.append(response_marker_begin + 1)
-
-            lengths = [len(x) for x in batch]
-
-            if max(lengths) > max_seq_length:
-                print(
-                    f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
-                    f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. "
-                    "Consider pre-splitting your data to save memory."
-                )
-
-            # Pad to the nearest multiple of 8 or the maximum length
-            pad_to = 8
-            max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
-            max_length_in_batch = min(max_length_in_batch, max_seq_length)
-
-            batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
-            for j in range(batch_size // step):
-                response_prefix_length = response_prefix_lengths[j]
-                truncated_length = min(lengths[j], max_seq_length)
-                batch_arr[j, response_prefix_length:truncated_length] = batch[j][
-                    response_prefix_length:truncated_length
-                ]
-                lengths[j] = (
-                    truncated_length  # Update lengths to match truncated lengths
-                )
-
-            yield mx.array(batch_arr), mx.array(response_prefix_lengths), mx.array(
-                lengths
-            )
-
-        if not train:
-            break
-
-
 def iterate_batches(
     dataset,
     tokenizer,
     batch_size,
     max_seq_length,
     train=False,
-    response_generation_tokens=None,
 ):
     # Sort by length:
     idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
@@ -227,6 +124,10 @@ def iterate_batches(
         indices = np.random.permutation(len(batch_idx))
         for i in indices:
             batch = [dataset[j] for j in batch_idx[i]]
+            if len(batch[0]) == 2:
+                batch, offsets = zip(*batch)
+            else:
+                offsets = [0] * len(batch)
             lengths = [len(x) for x in batch]
             if max(lengths) > max_seq_length:
                 print(
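Because masked items are (tokens, offset) tuples while unmasked items are plain token lists, the iterator normalizes both shapes here; a quick illustration of the zip(*batch) step (toy values):

    batch = [([1, 2, 3, 4], 2), ([5, 6, 7], 1)]  # (tokens, offset) items from a masked dataset
    tokens, offsets = zip(*batch)                # tokens: ([1, 2, 3, 4], [5, 6, 7]); offsets: (2, 1)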
@@ -249,8 +150,7 @@ def iterate_batches(
                     truncated_length  # Update lengths to match truncated lengths
                 )
             batch = mx.array(batch_arr)
-
-            yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
+            yield batch, mx.array(list(zip(offsets, lengths)))
 
         if not train:
             break
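The iterator now yields the padded batch together with a single (batch_size, 2) array of (offset, length) rows, which lines up with the two extra positional arguments of the new default_loss; evaluate and train simply unpack each yielded tuple as loss(model, *batch). A tiny sketch of the packing (assumes mlx is installed):

    import mlx.core as mx

    offsets = [3, 0]
    lengths = [6, 4]
    packed = mx.array(list(zip(offsets, lengths)))  # one (offset, length) row per example
    print(packed.shape)                             # (2, 2)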
@@ -265,7 +165,6 @@ def evaluate(
     max_seq_length=2048,
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
-    response_generation_tokens: Optional[List[int]] = None,
 ):
     all_losses = mx.array(0.0)
     ntokens = mx.array(0)
@@ -279,7 +178,6 @@ def evaluate(
             tokenizer=tokenizer,
             batch_size=batch_size,
             max_seq_length=max_seq_length,
-            response_generation_tokens=response_generation_tokens,
         ),
     ):
         losses, toks = loss(model, *batch)
@@ -355,7 +253,6 @@ def train(
             batch_size=args.batch_size,
             max_seq_length=args.max_seq_length,
             train=True,
-            response_generation_tokens=args.response_generation_tokens,
         ),
     ):
         # Report validation loss if needed, the first validation loss
@@ -371,7 +268,6 @@ def train(
                 num_batches=args.val_batches,
                 max_seq_length=args.max_seq_length,
                 iterate_batches=iterate_batches,
-                response_generation_tokens=args.response_generation_tokens,
             )
             val_time = time.perf_counter() - stop
             if rank == 0: