mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
fix encoding with special tokens + chat template (#1189)
This commit is contained in:
@@ -10,41 +10,47 @@ class Dataset:
|
||||
Light-weight wrapper to hold a dataset.
|
||||
"""
|
||||
|
||||
def __init__(self, data: List[Dict[str, str]], text_key: str = "text"):
|
||||
self._text_key = text_key
|
||||
self._data = data
|
||||
def __init__(
|
||||
self,
|
||||
data: List[Dict[str, str]],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
text_key: str = "text",
|
||||
):
|
||||
self._data = [tokenizer.encode(d[text_key]) for d in data]
|
||||
for d in self._data:
|
||||
if d[-1] != tokenizer.eos_token_id:
|
||||
d.append(tokenizer.eos_token_id)
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
return self._data[idx][self._text_key]
|
||||
return self._data[idx]
|
||||
|
||||
def __len__(self):
|
||||
if self._data is None:
|
||||
return 0
|
||||
return len(self._data)
|
||||
|
||||
|
||||
class ChatDataset(Dataset):
|
||||
class ChatDataset:
|
||||
"""
|
||||
A dataset for chat data in the format of {"messages": [...]}
|
||||
https://platform.openai.com/docs/guides/fine-tuning/example-format
|
||||
"""
|
||||
|
||||
def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
|
||||
super().__init__(data)
|
||||
self._tokenizer = tokenizer
|
||||
self._data = [
|
||||
tokenizer.apply_chat_template(
|
||||
d["messages"],
|
||||
tools=d.get("tools", None),
|
||||
)
|
||||
for d in data
|
||||
]
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
messages = self._data[idx]["messages"]
|
||||
text = self._tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tools=self._data[idx].get("tools", None),
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
return text
|
||||
return self._data[idx]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._data)
|
||||
|
||||
|
||||
class CompletionsDataset(Dataset):
|
||||
class CompletionsDataset:
|
||||
"""
|
||||
A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
|
||||
or using user-provided keys for prompt and completion values
|
||||
@@ -58,25 +64,24 @@ class CompletionsDataset(Dataset):
|
||||
prompt_key: str = "prompt",
|
||||
completion_key: str = "completion",
|
||||
):
|
||||
super().__init__(data)
|
||||
self._tokenizer = tokenizer
|
||||
self._prompt_key = prompt_key
|
||||
self._completion_key = completion_key
|
||||
self._data = [
|
||||
tokenizer.apply_chat_template(
|
||||
[
|
||||
{"role": "user", "content": d[prompt_key]},
|
||||
{"role": "assistant", "content": d[completion_key]},
|
||||
],
|
||||
)
|
||||
for d in data
|
||||
]
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
data = self._data[idx]
|
||||
text = self._tokenizer.apply_chat_template(
|
||||
[
|
||||
{"role": "user", "content": data[self._prompt_key]},
|
||||
{"role": "assistant", "content": data[self._completion_key]},
|
||||
],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
return text
|
||||
return self._data[idx]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._data)
|
||||
|
||||
|
||||
def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
|
||||
def create_dataset(data, tokenizer: PreTrainedTokenizer):
|
||||
sample = data[0]
|
||||
|
||||
if "messages" in sample:
|
||||
@@ -84,7 +89,7 @@ def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
|
||||
elif "prompt" in sample and "completion" in sample:
|
||||
return CompletionsDataset(data, tokenizer)
|
||||
elif "text" in sample:
|
||||
return Dataset(data)
|
||||
return Dataset(data, tokenizer)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported data format, check the supported formats here:\n"
|
||||
@@ -143,7 +148,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
|
||||
if prompt_feature and completion_feature:
|
||||
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
|
||||
elif text_feature:
|
||||
return Dataset(train_ds, text_key=text_feature)
|
||||
return Dataset(train_ds, tokenizer, text_key=text_feature)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Specify either a prompt and completion feature or a text "
|
||||
@@ -166,7 +171,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
|
||||
|
||||
|
||||
def load_dataset(args, tokenizer: PreTrainedTokenizer):
|
||||
if getattr(args, "hf_dataset", None) is not None:
|
||||
if getattr(args, "hf_dataset", False):
|
||||
train, valid, test = load_custom_hf_dataset(args, tokenizer)
|
||||
else:
|
||||
data_path = Path(args.data)
|
||||
|
@@ -100,14 +100,8 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
|
||||
while True:
|
||||
indices = np.random.permutation(len(batch_idx))
|
||||
for i in indices:
|
||||
# Encode batch
|
||||
batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
|
||||
for b in batch:
|
||||
if b[-1] != tokenizer.eos_token_id:
|
||||
b.append(tokenizer.eos_token_id)
|
||||
|
||||
batch = [dataset[j] for j in batch_idx[i]]
|
||||
lengths = [len(x) for x in batch]
|
||||
|
||||
if max(lengths) > max_seq_length:
|
||||
print(
|
||||
f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
|
||||
|
Reference in New Issue
Block a user