fix encoding with special tokens + chat template (#1189)

This commit is contained in:
Awni Hannun
2025-01-03 10:50:59 -08:00
committed by GitHub
parent 3a58c36109
commit c4833a2f55
13 changed files with 95 additions and 97 deletions

View File

@@ -10,41 +10,47 @@ class Dataset:
Light-weight wrapper to hold a dataset.
"""
def __init__(self, data: List[Dict[str, str]], text_key: str = "text"):
self._text_key = text_key
self._data = data
def __init__(
self,
data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
text_key: str = "text",
):
self._data = [tokenizer.encode(d[text_key]) for d in data]
for d in self._data:
if d[-1] != tokenizer.eos_token_id:
d.append(tokenizer.eos_token_id)
def __getitem__(self, idx: int):
return self._data[idx][self._text_key]
return self._data[idx]
def __len__(self):
if self._data is None:
return 0
return len(self._data)
class ChatDataset(Dataset):
class ChatDataset:
"""
A dataset for chat data in the format of {"messages": [...]}
https://platform.openai.com/docs/guides/fine-tuning/example-format
"""
def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
super().__init__(data)
self._tokenizer = tokenizer
self._data = [
tokenizer.apply_chat_template(
d["messages"],
tools=d.get("tools", None),
)
for d in data
]
def __getitem__(self, idx: int):
messages = self._data[idx]["messages"]
text = self._tokenizer.apply_chat_template(
messages,
tools=self._data[idx].get("tools", None),
tokenize=False,
add_generation_prompt=True,
)
return text
return self._data[idx]
def __len__(self):
return len(self._data)
class CompletionsDataset(Dataset):
class CompletionsDataset:
"""
A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
or using user-provided keys for prompt and completion values
@@ -58,25 +64,24 @@ class CompletionsDataset(Dataset):
prompt_key: str = "prompt",
completion_key: str = "completion",
):
super().__init__(data)
self._tokenizer = tokenizer
self._prompt_key = prompt_key
self._completion_key = completion_key
self._data = [
tokenizer.apply_chat_template(
[
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[completion_key]},
],
)
for d in data
]
def __getitem__(self, idx: int):
data = self._data[idx]
text = self._tokenizer.apply_chat_template(
[
{"role": "user", "content": data[self._prompt_key]},
{"role": "assistant", "content": data[self._completion_key]},
],
tokenize=False,
add_generation_prompt=True,
)
return text
return self._data[idx]
def __len__(self):
return len(self._data)
def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
def create_dataset(data, tokenizer: PreTrainedTokenizer):
sample = data[0]
if "messages" in sample:
@@ -84,7 +89,7 @@ def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
elif "prompt" in sample and "completion" in sample:
return CompletionsDataset(data, tokenizer)
elif "text" in sample:
return Dataset(data)
return Dataset(data, tokenizer)
else:
raise ValueError(
"Unsupported data format, check the supported formats here:\n"
@@ -143,7 +148,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
if prompt_feature and completion_feature:
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
elif text_feature:
return Dataset(train_ds, text_key=text_feature)
return Dataset(train_ds, tokenizer, text_key=text_feature)
else:
raise ValueError(
"Specify either a prompt and completion feature or a text "
@@ -166,7 +171,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
def load_dataset(args, tokenizer: PreTrainedTokenizer):
if getattr(args, "hf_dataset", None) is not None:
if getattr(args, "hf_dataset", False):
train, valid, test = load_custom_hf_dataset(args, tokenizer)
else:
data_path = Path(args.data)

View File

@@ -100,14 +100,8 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
while True:
indices = np.random.permutation(len(batch_idx))
for i in indices:
# Encode batch
batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
for b in batch:
if b[-1] != tokenizer.eos_token_id:
b.append(tokenizer.eos_token_id)
batch = [dataset[j] for j in batch_idx[i]]
lengths = [len(x) for x in batch]
if max(lengths) > max_seq_length:
print(
f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "