put offset in prompt, simplify

Awni Hannun 2025-02-09 17:31:23 -08:00
parent 6ace6dc6b2
commit 6e9542a934
3 changed files with 64 additions and 185 deletions

View File

@@ -12,7 +12,7 @@ import mlx.optimizers as optim
 import numpy as np
 import yaml
 
-from .tokenizer_utils import TokenizerWrapper, no_bos_or_eos
+from .tokenizer_utils import TokenizerWrapper
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import (
@@ -62,7 +62,6 @@ CONFIG_DEFAULTS = {
     "grad_checkpoint": False,
     "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
-    "response_template": None,
 }
@@ -97,10 +96,9 @@ def build_parser():
     )
     parser.add_argument(
-        "--mask-inputs",
-        dest="mask_inputs",
+        "--mask-prompt",
         action="store_true",
-        help="Whether to mask the inputs when training. Default is False.",
+        help="Mask the prompt in the loss when training",
         default=False,
     )
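Dropping the explicit dest= works because argparse derives the attribute name from the flag itself, so --mask-prompt is exposed as args.mask_prompt. A standalone sketch of that behavior (not part of this commit):

    import argparse

    # argparse maps "--mask-prompt" to the attribute name "mask_prompt" when no dest= is given.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mask-prompt",
        action="store_true",
        help="Mask the prompt in the loss when training",
        default=False,
    )
    args = parser.parse_args(["--mask-prompt"])
    print(args.mask_prompt)  # True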
@@ -182,13 +180,6 @@ def train_model(
     valid_set,
     training_callback: TrainingCallback = None,
 ):
-    from .tuner.trainer import (
-        default_loss,
-        input_masked_loss,
-        iterate_batches,
-        iterate_completion_batches,
-    )
-
     model.freeze()
     if args.fine_tune_type == "full":
         for l in model.layers[-min(args.num_layers, 0) :]:
@@ -217,17 +208,6 @@ def train_model(
     adapter_file = adapter_path / "adapters.safetensors"
     save_config(vars(args), adapter_path / "adapter_config.json")
 
-    if isinstance(args.response_template, str):
-        response_generation_tokens = tokenizer.encode(
-            args.response_template, add_special_tokens=False
-        )
-    else:
-        if not all([item.isinstance(int) for item in args.response_template]):
-            raise ValueError(
-                "Response template must be a list of integers if it is not a string."
-            )
-        response_generation_tokens = args.response_template
-
     # init training args
     training_args = TrainingArgs(
         batch_size=args.batch_size,
@@ -239,9 +219,6 @@ def train_model(
         adapter_file=adapter_file,
         max_seq_length=args.max_seq_length,
         grad_checkpoint=args.grad_checkpoint,
-        response_generation_tokens=no_bos_or_eos(
-            response_generation_tokens, tokenizer.bos_token_id, tokenizer.eos_token_id
-        ),
     )
 
     model.train()
@@ -251,9 +228,6 @@ def train_model(
         )
     )
 
-    if args.mask_inputs:
-        print("Masking inputs..")
-
     # Train model
     train(
         model=model,
@@ -263,10 +237,6 @@
         train_dataset=train_set,
         val_dataset=valid_set,
         training_callback=training_callback,
-        iterate_batches=(
-            iterate_completion_batches if args.mask_inputs else iterate_batches
-        ),
-        loss=input_masked_loss if args.mask_inputs else default_loss,
     )

View File

@@ -40,14 +40,19 @@ class ChatDataset:
         data: List[Dict[str, str]],
         tokenizer: PreTrainedTokenizer,
         chat_key: str = "messages",
+        mask_prompt: bool = False,
     ):
-        self._data = [
-            tokenizer.apply_chat_template(
-                d[chat_key],
-                tools=d.get("tools", None),
-            )
-            for d in data
-        ]
+        self._data = []
+        for d in data:
+            messages = d[chat_key]
+            tools = d.get("tools", None)
+            tokens = tokenizer.apply_chat_template(messages, tools=tools)
+            if mask_prompt:
+                messages = messages[:-1]
+                offset = len(tokenizer.apply_chat_template(messages, tools=tools))
+                self._data.append((tokens, offset))
+            else:
+                self._data.append(tokens)
 
     def __getitem__(self, idx: int):
         return self._data[idx]
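The offset is just the tokenized length of the conversation without its final assistant turn; CompletionsDataset below does the same with a prompt-only message list. A standalone sketch of that bookkeeping (the model name is only an example; any tokenizer with a chat template behaves the same way):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # example model
    messages = [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "Paris."},
    ]
    tokens = tokenizer.apply_chat_template(messages)            # full conversation as token ids
    offset = len(tokenizer.apply_chat_template(messages[:-1]))  # length of the prompt prefix
    item = (tokens, offset)  # what the dataset stores per example when mask_prompt is enabled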
@@ -69,16 +74,25 @@ class CompletionsDataset:
         tokenizer: PreTrainedTokenizer,
         prompt_key: str,
         completion_key: str,
+        mask_prompt: bool,
     ):
-        self._data = [
-            tokenizer.apply_chat_template(
+        self._data = []
+        for d in data:
+            tokens = tokenizer.apply_chat_template(
                 [
                     {"role": "user", "content": d[prompt_key]},
                     {"role": "assistant", "content": d[completion_key]},
                 ],
             )
-            for d in data
-        ]
+            if mask_prompt:
+                offset = len(
+                    tokenizer.apply_chat_template(
+                        [{"role": "user", "content": d[prompt_key]}]
+                    )
+                )
+                self._data.append((tokens, offset))
+            else:
+                self._data.append(tokens)
 
     def __getitem__(self, idx: int):
         return self._data[idx]
@@ -101,17 +115,21 @@ class ConcatenatedDataset:
 def create_dataset(
     data,
     tokenizer: PreTrainedTokenizer,
-    prompt_feature: Optional[str] = None,
-    completion_feature: Optional[str] = None,
+    config: Dict,
 ):
-    prompt_feature = prompt_feature or "prompt"
-    completion_feature = completion_feature or "completion"
+    mask_prompt = getattr(config, "mask_prompt", False)
+    prompt_feature = getattr(config, "prompt_feature", "prompt")
+    completion_feature = getattr(config, "completion_feature", "completion")
     sample = data[0]
     if "messages" in sample:
-        return ChatDataset(data, tokenizer)
+        return ChatDataset(data, tokenizer, mask_prompt=mask_prompt)
     elif prompt_feature in sample and completion_feature in sample:
-        return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
+        return CompletionsDataset(
+            data, tokenizer, prompt_feature, completion_feature, mask_prompt
+        )
     elif "text" in sample:
+        if mask_prompt:
+            raise ValueError("Prompt masking not supported for text dataset.")
         return Dataset(data, tokenizer)
     else:
         raise ValueError(
@@ -123,15 +141,14 @@ def create_dataset(
 def load_local_dataset(
     data_path: Path,
     tokenizer: PreTrainedTokenizer,
-    prompt_feature: Optional[str] = None,
-    completion_feature: Optional[str] = None,
+    config: Dict,
 ):
     def load_subset(path):
         if not path.exists():
             return []
         with open(path, "r") as fid:
             data = [json.loads(l) for l in fid]
-        return create_dataset(data, tokenizer, prompt_feature, completion_feature)
+        return create_dataset(data, tokenizer, config)
 
     names = ("train", "valid", "test")
     train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
@@ -141,8 +158,7 @@ def load_local_dataset(
 def load_hf_dataset(
     data_id: str,
     tokenizer: PreTrainedTokenizer,
-    prompt_feature: Optional[str] = None,
-    completion_feature: Optional[str] = None,
+    config: Dict,
 ):
     from datasets import exceptions, load_dataset
@@ -153,9 +169,7 @@ def load_hf_dataset(
         train, valid, test = [
             (
-                create_dataset(
-                    dataset[n], tokenizer, prompt_feature, completion_feature
-                )
+                create_dataset(dataset[n], tokenizer, config)
                 if n in dataset.keys()
                 else []
             )
@@ -186,10 +200,16 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
             **config,
         )
         if prompt_feature and completion_feature:
-            return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
+            return CompletionsDataset(
+                data, tokenizer, prompt_feature, completion_feature, mask_prompt
+            )
         elif chat_feature:
-            return ChatDataset(ds, tokenizer, chat_key=chat_feature)
+            return ChatDataset(
+                ds, tokenizer, chat_key=chat_feature, mask_prompt=mask_prompt
+            )
         elif text_feature:
+            if mask_prompt:
+                raise ValueError("Prompt masking not supported for text dataset.")
             return Dataset(ds, tokenizer, text_key=text_feature)
         else:
             raise ValueError(
@@ -262,18 +282,11 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
         train, valid, test = load_custom_hf_dataset(args, tokenizer)
     else:
         data_path = Path(args.data)
-        prompt_feature = getattr(args, "prompt_feature", None)
-        completion_feature = getattr(args, "completion_feature", None)
-
         if data_path.exists():
-            train, valid, test = load_local_dataset(
-                data_path, tokenizer, prompt_feature, completion_feature
-            )
+            train, valid, test = load_local_dataset(data_path, tokenizer, args)
         else:
             print(f"Loading Hugging Face dataset {args.data}.")
-            train, valid, test = load_hf_dataset(
-                args.data, tokenizer, prompt_feature, completion_feature
-            )
+            train, valid, test = load_hf_dataset(args.data, tokenizer, args)
 
     if args.train and len(train) == 0:
         raise ValueError(
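Since create_dataset now reads mask_prompt and the feature names straight from the args/config object, callers only pass that object through. A usage sketch (import path, model name, and sample data are assumptions, not part of this commit):

    from types import SimpleNamespace
    from transformers import AutoTokenizer
    from mlx_lm.tuner.datasets import create_dataset  # import path assumed

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # example model
    config = SimpleNamespace(
        mask_prompt=True, prompt_feature="prompt", completion_feature="completion"
    )
    data = [{"prompt": "What is 2+2?", "completion": "4"}]
    train_set = create_dataset(data, tokenizer, config)
    tokens, offset = train_set[0]  # (token ids, prompt-prefix length) because mask_prompt is on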

View File

@@ -64,39 +64,20 @@ class TrainingArgs:
         default=False,
         metadata={"help": "Use gradient checkpointing to reduce memory use."},
     )
-    response_generation_tokens: Optional[List[int]] = field(
-        default_factory=list,
-        metadata={"help": "List of token ids that mark the beginning of the response"},
-    )
 
 
-def input_masked_loss(model, inputs, response_prefix_lengths, lengths):
-    shifted_inputs = inputs[:, :-1]
-    shifted_labels = inputs[:, 1:]
-    logits = model(shifted_inputs)
-    logits = logits.astype(mx.float32)
-
-    mask_width = shifted_inputs.shape[1]
-    token_indices = mx.arange(mask_width)[None, :]
-    mask = mx.logical_and(
-        token_indices >= response_prefix_lengths[:, None],
-        token_indices < lengths[:, None],
-    )
-
-    ce = nn.losses.cross_entropy(logits, shifted_labels) * mask
-    ntoks = mask.sum()
-    ce = ce.sum() / ntoks
-    return ce, ntoks
-
-
-def default_loss(model, inputs, targets, lengths):
+def default_loss(model, batch, lengths):
+    inputs = batch[:, :-1]
+    targets = batch[:, 1:]
+
     logits = model(inputs)
     logits = logits.astype(mx.float32)
 
-    length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
+    steps = mx.arange(1, targets.shape[1] + 1)
+    mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
 
-    ce = nn.losses.cross_entropy(logits, targets) * length_mask
-    ntoks = length_mask.sum()
+    ce = nn.losses.cross_entropy(logits, targets) * mask
+    ntoks = mask.sum()
     ce = ce.sum() / ntoks
     return ce, ntoks
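The loss now receives an (offset, length) pair per row, so a single mask removes both the prompt prefix and the padding tail. A numeric illustration with invented values:

    import mlx.core as mx

    lengths = mx.array([[3, 6]])  # one row: prompt offset 3, true sequence length 6
    steps = mx.arange(1, 9)       # 1-based target positions for 8 padded targets
    mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
    print(mask)  # [[False, False, True, True, True, True, False, False]]

When masking is off, iterate_batches supplies an offset of 0, so the first condition is always true and this reduces to the previous length-only mask.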
@@ -112,96 +93,12 @@ def contains(small_list: List, big_list: List) -> Tuple[int, int]:
     return ind, ind + small_list_length - 1
 
 
-def iterate_completion_batches(
-    dataset: CompletionsDataset,
-    tokenizer: PreTrainedTokenizer,
-    batch_size: int,
-    max_seq_length: int,
-    train: bool = False,
-    response_generation_tokens: Optional[List[int]] = None,
-):
-    """
-    A version of iterate_batches that works with completion datasets, tracks the boundaries between input/output tokens
-    and returns the lengths of input tokens as well as that of the full sequences.
-    """
-    idx = sorted(range(len(dataset)), key=lambda i: len(dataset[i]))
-    if len(dataset) < batch_size:
-        raise ValueError(
-            f"Dataset must have at least batch_size={batch_size}"
-            f" examples but only has {len(dataset)}."
-        )
-
-    # If running in distributed mode (N machines) then each one should skip N-1
-    # samples
-    step = mx.distributed.init().size()
-    if batch_size % step != 0:
-        raise ValueError("The batch size must be divisible by the number of workers")
-
-    # Make the batches:
-    batch_idx = [
-        idx[i : i + batch_size : step]
-        for i in range(0, len(idx) - batch_size + 1, batch_size)
-    ]
-
-    while True:
-        indices = np.random.permutation(len(batch_idx))
-        for i in indices:
-            response_prefix_lengths = []
-            batch = []
-            for j in batch_idx[i]:
-                full_sequence = dataset.get_item(j, tokenize=True)
-                if full_sequence[-1] != tokenizer.eos_token_id:
-                    full_sequence.append(tokenizer.eos_token_id)
-                batch.append(full_sequence)
-
-                if len(response_generation_tokens) > 1:
-                    response_marker_begin, response_marker_end = contains(
-                        response_generation_tokens, full_sequence
-                    )
-                    response_prefix_lengths.append(response_marker_end + 1)
-                else:
-                    response_marker_begin = full_sequence.index(
-                        response_generation_tokens[0]
-                    )
-                    response_prefix_lengths.append(response_marker_begin + 1)
-
-            lengths = [len(x) for x in batch]
-            if max(lengths) > max_seq_length:
-                print(
-                    f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
-                    f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. "
-                    "Consider pre-splitting your data to save memory."
-                )
-
-            # Pad to the nearest multiple of 8 or the maximum length
-            pad_to = 8
-            max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
-            max_length_in_batch = min(max_length_in_batch, max_seq_length)
-
-            batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
-            for j in range(batch_size // step):
-                response_prefix_length = response_prefix_lengths[j]
-                truncated_length = min(lengths[j], max_seq_length)
-                batch_arr[j, response_prefix_length:truncated_length] = batch[j][
-                    response_prefix_length:truncated_length
-                ]
-                lengths[j] = (
-                    truncated_length  # Update lengths to match truncated lengths
-                )
-            yield mx.array(batch_arr), mx.array(response_prefix_lengths), mx.array(
-                lengths
-            )
-
-        if not train:
-            break
-
-
 def iterate_batches(
     dataset,
     tokenizer,
     batch_size,
     max_seq_length,
     train=False,
-    response_generation_tokens=None,
 ):
     # Sort by length:
     idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
@@ -227,6 +124,10 @@ def iterate_batches(
         indices = np.random.permutation(len(batch_idx))
         for i in indices:
             batch = [dataset[j] for j in batch_idx[i]]
+            if len(batch[0]) == 2:
+                batch, offsets = zip(*batch)
+            else:
+                offsets = [0] * len(batch)
             lengths = [len(x) for x in batch]
             if max(lengths) > max_seq_length:
                 print(
@@ -249,8 +150,7 @@ def iterate_batches(
                     truncated_length  # Update lengths to match truncated lengths
                 )
             batch = mx.array(batch_arr)
-
-            yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
+            yield batch, mx.array(list(zip(offsets, lengths)))
 
         if not train:
             break
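A toy illustration (invented token ids) of the batch protocol handled above: with mask_prompt the dataset items are (tokens, offset) pairs, otherwise plain token lists with an implied offset of 0.

    batch = [([101, 7, 8, 9, 102], 2), ([101, 5, 6, 102], 1)]  # two (tokens, offset) items
    if len(batch[0]) == 2:
        batch, offsets = zip(*batch)
    else:
        offsets = [0] * len(batch)
    lengths = [len(x) for x in batch]
    print(list(zip(offsets, lengths)))  # [(2, 5), (1, 4)] -> the second array yielded to the loss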
@@ -265,7 +165,6 @@ def evaluate(
     max_seq_length=2048,
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
-    response_generation_tokens: Optional[List[int]] = None,
 ):
     all_losses = mx.array(0.0)
     ntokens = mx.array(0)
@@ -279,7 +178,6 @@
             tokenizer=tokenizer,
             batch_size=batch_size,
             max_seq_length=max_seq_length,
-            response_generation_tokens=response_generation_tokens,
         ),
     ):
         losses, toks = loss(model, *batch)
@@ -355,7 +253,6 @@ def train(
             batch_size=args.batch_size,
             max_seq_length=args.max_seq_length,
             train=True,
-            response_generation_tokens=args.response_generation_tokens,
         ),
     ):
         # Report validation loss if needed, the first validation loss
@@ -371,7 +268,6 @@ def train(
                 num_batches=args.val_batches,
                 max_seq_length=args.max_seq_length,
                 iterate_batches=iterate_batches,
-                response_generation_tokens=args.response_generation_tokens,
             )
             val_time = time.perf_counter() - stop
             if rank == 0: