put offset in prompt, simplify

This commit is contained in:
Awni Hannun 2025-02-09 17:31:23 -08:00
parent 6ace6dc6b2
commit 6e9542a934
3 changed files with 64 additions and 185 deletions

View File

@ -12,7 +12,7 @@ import mlx.optimizers as optim
import numpy as np
import yaml
from .tokenizer_utils import TokenizerWrapper, no_bos_or_eos
from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import (
@ -62,7 +62,6 @@ CONFIG_DEFAULTS = {
"grad_checkpoint": False,
"lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
"response_template": None,
}
@ -97,10 +96,9 @@ def build_parser():
)
parser.add_argument(
"--mask-inputs",
dest="mask_inputs",
"--mask-prompt",
action="store_true",
help="Whether to mask the inputs when training. Default is False.",
help="Mask the prompt in the loss when training",
default=False,
)
@ -182,13 +180,6 @@ def train_model(
valid_set,
training_callback: TrainingCallback = None,
):
from .tuner.trainer import (
default_loss,
input_masked_loss,
iterate_batches,
iterate_completion_batches,
)
model.freeze()
if args.fine_tune_type == "full":
for l in model.layers[-min(args.num_layers, 0) :]:
@ -217,17 +208,6 @@ def train_model(
adapter_file = adapter_path / "adapters.safetensors"
save_config(vars(args), adapter_path / "adapter_config.json")
if isinstance(args.response_template, str):
response_generation_tokens = tokenizer.encode(
args.response_template, add_special_tokens=False
)
else:
if not all([item.isinstance(int) for item in args.response_template]):
raise ValueError(
"Response template must be a list of integers if it is not a string."
)
response_generation_tokens = args.response_template
# init training args
training_args = TrainingArgs(
batch_size=args.batch_size,
@ -239,9 +219,6 @@ def train_model(
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint,
response_generation_tokens=no_bos_or_eos(
response_generation_tokens, tokenizer.bos_token_id, tokenizer.eos_token_id
),
)
model.train()
@ -251,9 +228,6 @@ def train_model(
)
)
if args.mask_inputs:
print("Masking inputs..")
# Train model
train(
model=model,
@ -263,10 +237,6 @@ def train_model(
train_dataset=train_set,
val_dataset=valid_set,
training_callback=training_callback,
iterate_batches=(
iterate_completion_batches if args.mask_inputs else iterate_batches
),
loss=input_masked_loss if args.mask_inputs else default_loss,
)

View File

@ -40,14 +40,19 @@ class ChatDataset:
data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
chat_key: str = "messages",
mask_prompt: bool = False,
):
self._data = [
tokenizer.apply_chat_template(
d[chat_key],
tools=d.get("tools", None),
)
for d in data
]
self._data = []
for d in data:
messages = d[chat_key]
tools = d.get("tools", None)
tokens = tokenizer.apply_chat_template(messages, tools=tools)
if mask_prompt:
messages = messages[:-1]
offset = len(tokenizer.apply_chat_template(messages, tools=tools))
self._data.append((tokens, offset))
else:
self._data.append(tokens)
def __getitem__(self, idx: int):
return self._data[idx]
@ -69,16 +74,25 @@ class CompletionsDataset:
tokenizer: PreTrainedTokenizer,
prompt_key: str,
completion_key: str,
mask_prompt: bool,
):
self._data = [
tokenizer.apply_chat_template(
self._data = []
for d in data:
tokens = tokenizer.apply_chat_template(
[
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[completion_key]},
],
)
for d in data
]
if mask_prompt:
offset = len(
tokenizer.apply_chat_template(
[{"role": "user", "content": d[prompt_key]}]
)
)
self._data.append((tokens, offset))
else:
self._data.append(tokens)
def __getitem__(self, idx: int):
return self._data[idx]
@ -101,17 +115,21 @@ class ConcatenatedDataset:
def create_dataset(
data,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
completion_feature: Optional[str] = None,
config: Dict,
):
prompt_feature = prompt_feature or "prompt"
completion_feature = completion_feature or "completion"
mask_prompt = getattr(config, "mask_prompt", False)
prompt_feature = getattr(config, "prompt_feature", "prompt")
completion_feature = getattr(config, "completion_feature", "completion")
sample = data[0]
if "messages" in sample:
return ChatDataset(data, tokenizer)
return ChatDataset(data, tokenizer, mask_prompt=mask_prompt)
elif prompt_feature in sample and completion_feature in sample:
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
return CompletionsDataset(
data, tokenizer, prompt_feature, completion_feature, mask_prompt
)
elif "text" in sample:
if mask_prompt:
raise ValueError("Prompt masking not supported for text dataset.")
return Dataset(data, tokenizer)
else:
raise ValueError(
@ -123,15 +141,14 @@ def create_dataset(
def load_local_dataset(
data_path: Path,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
completion_feature: Optional[str] = None,
config: Dict,
):
def load_subset(path):
if not path.exists():
return []
with open(path, "r") as fid:
data = [json.loads(l) for l in fid]
return create_dataset(data, tokenizer, prompt_feature, completion_feature)
return create_dataset(data, tokenizer, config)
names = ("train", "valid", "test")
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
@ -141,8 +158,7 @@ def load_local_dataset(
def load_hf_dataset(
data_id: str,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
completion_feature: Optional[str] = None,
config: Dict,
):
from datasets import exceptions, load_dataset
@ -153,9 +169,7 @@ def load_hf_dataset(
train, valid, test = [
(
create_dataset(
dataset[n], tokenizer, prompt_feature, completion_feature
)
create_dataset(dataset[n], tokenizer, config)
if n in dataset.keys()
else []
)
@ -186,10 +200,16 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
**config,
)
if prompt_feature and completion_feature:
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
return CompletionsDataset(
data, tokenizer, prompt_feature, completion_feature, mask_prompt
)
elif chat_feature:
return ChatDataset(ds, tokenizer, chat_key=chat_feature)
return ChatDataset(
ds, tokenizer, chat_key=chat_feature, mask_prompt=mask_prompt
)
elif text_feature:
if mask_prompt:
raise ValueError("Prompt masking not supported for text dataset.")
return Dataset(ds, tokenizer, text_key=text_feature)
else:
raise ValueError(
@ -262,18 +282,11 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
train, valid, test = load_custom_hf_dataset(args, tokenizer)
else:
data_path = Path(args.data)
prompt_feature = getattr(args, "prompt_feature", None)
completion_feature = getattr(args, "completion_feature", None)
if data_path.exists():
train, valid, test = load_local_dataset(
data_path, tokenizer, prompt_feature, completion_feature
)
train, valid, test = load_local_dataset(data_path, tokenizer, args)
else:
print(f"Loading Hugging Face dataset {args.data}.")
train, valid, test = load_hf_dataset(
args.data, tokenizer, prompt_feature, completion_feature
)
train, valid, test = load_hf_dataset(args.data, tokenizer, args)
if args.train and len(train) == 0:
raise ValueError(

View File

@ -64,39 +64,20 @@ class TrainingArgs:
default=False,
metadata={"help": "Use gradient checkpointing to reduce memory use."},
)
response_generation_tokens: Optional[List[int]] = field(
default_factory=list,
metadata={"help": "List of token ids that mark the beginning of the response"},
)
def input_masked_loss(model, inputs, response_prefix_lengths, lengths):
shifted_inputs = inputs[:, :-1]
shifted_labels = inputs[:, 1:]
logits = model(shifted_inputs)
logits = logits.astype(mx.float32)
def default_loss(model, batch, lengths):
inputs = batch[:, :-1]
targets = batch[:, 1:]
mask_width = shifted_inputs.shape[1]
token_indices = mx.arange(mask_width)[None, :]
mask = mx.logical_and(
token_indices >= response_prefix_lengths[:, None],
token_indices < lengths[:, None],
)
ce = nn.losses.cross_entropy(logits, shifted_labels) * mask
ntoks = mask.sum()
ce = ce.sum() / ntoks
return ce, ntoks
def default_loss(model, inputs, targets, lengths):
logits = model(inputs)
logits = logits.astype(mx.float32)
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
steps = mx.arange(1, targets.shape[1] + 1)
mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
ce = nn.losses.cross_entropy(logits, targets) * length_mask
ntoks = length_mask.sum()
ce = nn.losses.cross_entropy(logits, targets) * mask
ntoks = mask.sum()
ce = ce.sum() / ntoks
return ce, ntoks
@ -112,96 +93,12 @@ def contains(small_list: List, big_list: List) -> Tuple[int, int]:
return ind, ind + small_list_length - 1
def iterate_completion_batches(
dataset: CompletionsDataset,
tokenizer: PreTrainedTokenizer,
batch_size: int,
max_seq_length: int,
train: bool = False,
response_generation_tokens: Optional[List[int]] = None,
):
"""
A version of iterate_batches that works with completion datasets, tracks the boundaries between input/output tokens
and returns the lengths of input tokens as well as that of the full sequences.
"""
idx = sorted(range(len(dataset)), key=lambda i: len(dataset[i]))
if len(dataset) < batch_size:
raise ValueError(
f"Dataset must have at least batch_size={batch_size}"
f" examples but only has {len(dataset)}."
)
# If running in distributed mode (N machines) then each one should skip N-1
# samples
step = mx.distributed.init().size()
if batch_size % step != 0:
raise ValueError("The batch size must be divisible by the number of workers")
# Make the batches:
batch_idx = [
idx[i : i + batch_size : step]
for i in range(0, len(idx) - batch_size + 1, batch_size)
]
while True:
indices = np.random.permutation(len(batch_idx))
for i in indices:
response_prefix_lengths = []
batch = []
for j in batch_idx[i]:
full_sequence = dataset.get_item(j, tokenize=True)
if full_sequence[-1] != tokenizer.eos_token_id:
full_sequence.append(tokenizer.eos_token_id)
batch.append(full_sequence)
if len(response_generation_tokens) > 1:
response_marker_begin, response_marker_end = contains(
response_generation_tokens, full_sequence
)
response_prefix_lengths.append(response_marker_end + 1)
else:
response_marker_begin = full_sequence.index(
response_generation_tokens[0]
)
response_prefix_lengths.append(response_marker_begin + 1)
lengths = [len(x) for x in batch]
if max(lengths) > max_seq_length:
print(
f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. "
"Consider pre-splitting your data to save memory."
)
# Pad to the nearest multiple of 8 or the maximum length
pad_to = 8
max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
max_length_in_batch = min(max_length_in_batch, max_seq_length)
batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
for j in range(batch_size // step):
response_prefix_length = response_prefix_lengths[j]
truncated_length = min(lengths[j], max_seq_length)
batch_arr[j, response_prefix_length:truncated_length] = batch[j][
response_prefix_length:truncated_length
]
lengths[j] = (
truncated_length # Update lengths to match truncated lengths
)
yield mx.array(batch_arr), mx.array(response_prefix_lengths), mx.array(
lengths
)
if not train:
break
def iterate_batches(
dataset,
tokenizer,
batch_size,
max_seq_length,
train=False,
response_generation_tokens=None,
):
# Sort by length:
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
@ -227,6 +124,10 @@ def iterate_batches(
indices = np.random.permutation(len(batch_idx))
for i in indices:
batch = [dataset[j] for j in batch_idx[i]]
if len(batch[0]) == 2:
batch, offsets = zip(*batch)
else:
offsets = [0] * len(batch)
lengths = [len(x) for x in batch]
if max(lengths) > max_seq_length:
print(
@ -249,8 +150,7 @@ def iterate_batches(
truncated_length # Update lengths to match truncated lengths
)
batch = mx.array(batch_arr)
yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
yield batch, mx.array(list(zip(offsets, lengths)))
if not train:
break
@ -265,7 +165,6 @@ def evaluate(
max_seq_length=2048,
loss: callable = default_loss,
iterate_batches: callable = iterate_batches,
response_generation_tokens: Optional[List[int]] = None,
):
all_losses = mx.array(0.0)
ntokens = mx.array(0)
@ -279,7 +178,6 @@ def evaluate(
tokenizer=tokenizer,
batch_size=batch_size,
max_seq_length=max_seq_length,
response_generation_tokens=response_generation_tokens,
),
):
losses, toks = loss(model, *batch)
@ -355,7 +253,6 @@ def train(
batch_size=args.batch_size,
max_seq_length=args.max_seq_length,
train=True,
response_generation_tokens=args.response_generation_tokens,
),
):
# Report validation loss if needed, the first validation loss
@ -371,7 +268,6 @@ def train(
num_batches=args.val_batches,
max_seq_length=args.max_seq_length,
iterate_batches=iterate_batches,
response_generation_tokens=args.response_generation_tokens,
)
val_time = time.perf_counter() - stop
if rank == 0: