fix encoding with special tokens + chat template (#1189)

Awni Hannun 2025-01-03 10:50:59 -08:00 committed by GitHub
parent 3a58c36109
commit c4833a2f55
13 changed files with 95 additions and 97 deletions

View File

@@ -58,7 +58,7 @@ prompt = "Write a story about Einstein"
 messages = [{"role": "user", "content": prompt}]
 prompt = tokenizer.apply_chat_template(
-    messages, tokenize=False, add_generation_prompt=True
+    messages, add_generation_prompt=True
 )
 text = generate(model, tokenizer, prompt=prompt, verbose=True)
@@ -115,7 +115,7 @@ prompt = "Write a story about Einstein"
 messages = [{"role": "user", "content": prompt}]
 prompt = tokenizer.apply_chat_template(
-    messages, tokenize=False, add_generation_prompt=True
+    messages, add_generation_prompt=True
 )
 for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
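
Note: with tokenize=False dropped, apply_chat_template returns token ids rather than a string, and generate accepts them directly. A minimal usage sketch (assuming mlx_lm's public load/generate API; the model repo below is only an example):

from mlx_lm import load, generate

# Example 4-bit model repo; substitute any supported model.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
messages = [{"role": "user", "content": "Write a story about Einstein"}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)  # token ids
text = generate(model, tokenizer, prompt=prompt, verbose=True)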

View File

@@ -110,29 +110,17 @@ def main():
     if tokenizer.chat_template is None:
         tokenizer.chat_template = tokenizer.default_chat_template

-    if not args.ignore_chat_template and (
-        hasattr(tokenizer, "apply_chat_template")
-        and tokenizer.chat_template is not None
-    ):
+    if not args.ignore_chat_template and tokenizer.chat_template is not None:
         messages = [{"role": "user", "content": args.prompt}]
         prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
+            messages, add_generation_prompt=False, continue_final_message=True
         )
-        # Treat the prompt as a prefix assuming that the suffix will be
-        # provided at generation time.
-        test_prompt = tokenizer.apply_chat_template(
-            [{"role": "user", "content": "<query>"}],
-            tokenize=False,
-            add_generation_prompt=True,
-        )
-        n = len(test_prompt) - test_prompt.index("<query>") - len("<query>")
-        prompt = prompt[:-n]
     else:
-        prompt = args.prompt
+        prompt = tokenizer.encode(args.prompt)

     cache = make_prompt_cache(model, args.max_kv_size)

-    y = mx.array(tokenizer.encode(prompt))
+    y = mx.array(prompt)

     # Process the prompt
     start = time.time()
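
Note: the cached prefix is now built with add_generation_prompt=False and continue_final_message=True, so the chat template leaves the final user message open; the text supplied later at generation time continues that same message, which is why the old "<query>" suffix-trimming trick is no longer needed. A rough sketch of the idea (assumes the installed transformers version supports continue_final_message):

messages = [{"role": "user", "content": "The part of the question to cache"}]
# Token ids for an unclosed user turn: no end-of-turn marker or assistant header yet.
prefix = tokenizer.apply_chat_template(
    messages, add_generation_prompt=False, continue_final_message=True
)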

View File

@@ -72,9 +72,7 @@ def main():
         if query == "q":
             break
         messages = [{"role": "user", "content": query}]
-        prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
+        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
         for response in stream_generate(
             model,
             tokenizer,

View File

@@ -1,4 +1,8 @@
-# Adapted from a PyTorch implementation by David Grangier
+# Copyright © 2024 Apple Inc.
+
+"""
+Adapted from a PyTorch implementation by David Grangier
+"""

 import argparse
 import json
@@ -6,7 +10,7 @@ import logging
 import os
 from importlib.metadata import version
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Union

 import lm_eval
 import mlx.core as mx
@@ -277,19 +281,19 @@ class MLXLM(LM):
         assert "until" in keys
         untils = [x["until"] for x in options]
         completions = []

         for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
-            if (
-                hasattr(self._tokenizer, "apply_chat_template")
-                and self._tokenizer.chat_template is not None
-            ):
+            if self._tokenizer.chat_template is not None:
                 messages = [{"role": "user", "content": context}]
                 context = self._tokenizer.apply_chat_template(
-                    messages, tokenize=False, add_generation_prompt=True
+                    messages, add_generation_prompt=True
                 )
+            else:
+                context = self._tokenizer.encode(context)
+
             max_tokens = min(
                 self._max_tokens,
-                self._tokenizer.model_max_length - len(self._tokenizer.encode(context)),
+                self._tokenizer.model_max_length - len(context),
             )
             text = ""
             for response in stream_generate(
@@ -321,6 +325,12 @@ def main():
         type=int,
         help="Maximum nunber of tokens to generate. Defaults to the model's max context length.",
     )
+    parser.add_argument(
+        "--limit",
+        default=1.0,
+        help="Limit the number of examples per task.",
+        type=float,
+    )
     parser.add_argument("--seed", type=int, default=123, help="Random seed.")

     args = parser.parse_args()
@@ -338,10 +348,12 @@ def main():
         model=lm,
         tasks=args.tasks,
         num_fewshot=args.num_shots,
+        limit=args.limit,
         random_seed=args.seed,
         numpy_random_seed=args.seed,
         torch_random_seed=args.seed,
         fewshot_random_seed=args.seed,
+        apply_chat_template=True,
     )

     model_name = args.model.replace("/", "_")
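
Note: --limit is forwarded to lm_eval.simple_evaluate, and apply_chat_template=True asks lm-eval to format requests with the model's chat template. A hedged usage sketch (the --model and --tasks flag names are inferred from the surrounding parser, and the module path assumes evaluate.py is importable as mlx_lm.evaluate; lm-eval treats a float below 1.0 as a fraction of each task's examples):

python -m mlx_lm.evaluate --model <path-or-hf-repo> --tasks <task ...> --limit 0.1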

View File

@@ -15,9 +15,7 @@ prompt_cache = make_prompt_cache(model)
 # User turn
 prompt = "Hi my name is <Name>."
 messages = [{"role": "user", "content": prompt}]
-prompt = tokenizer.apply_chat_template(
-    messages, tokenize=False, add_generation_prompt=True
-)
+prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

 # Assistant response
 response = generate(
@@ -32,9 +30,7 @@ response = generate(
 # User turn
 prompt = "What's my name?"
 messages = [{"role": "user", "content": prompt}]
-prompt = tokenizer.apply_chat_template(
-    messages, tokenize=False, add_generation_prompt=True
-)
+prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

 # Assistant response
 response = generate(

View File

@@ -14,7 +14,7 @@ conversation = [{"role": "user", "content": prompt}]
 # Transform the prompt into the chat template
 prompt = tokenizer.apply_chat_template(
-    conversation=conversation, tokenize=False, add_generation_prompt=True
+    conversation=conversation, add_generation_prompt=True
 )

 # Specify the maximum number of tokens

View File

@@ -190,10 +190,7 @@ def main():
     prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t")
     prompt = sys.stdin.read() if prompt == "-" else prompt

-    if not args.ignore_chat_template and (
-        hasattr(tokenizer, "apply_chat_template")
-        and tokenizer.chat_template is not None
-    ):
+    if not args.ignore_chat_template and tokenizer.chat_template is not None:
         if args.system_prompt is not None:
             messages = [{"role": "system", "content": args.system_prompt}]
         else:
@@ -214,6 +211,10 @@ def main():
             )
             prompt = prompt[test_prompt.index("<query>") :]
+        prompt = tokenizer.encode(prompt, add_special_tokens=False)
+    else:
+        prompt = tokenizer.encode(prompt)

     sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
     response = generate(
         model,
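
Note: when the prompt went through the chat template it already contains any special tokens (e.g. a leading BOS), so it is encoded with add_special_tokens=False to avoid duplicating them; a raw prompt still uses the tokenizer's default behavior. A small sketch of the two paths (the template string is only illustrative):

templated = "<s>[INST] Write a story about Einstein [/INST]"  # template already added BOS
ids_templated = tokenizer.encode(templated, add_special_tokens=False)  # no second BOS
ids_plain = tokenizer.encode("Write a story about Einstein")  # tokenizer may add BOS itself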

View File

@@ -2,6 +2,7 @@

 import argparse
 import math
+import os
 import re
 import types
 from pathlib import Path
@@ -271,6 +272,7 @@ def run(args, training_callback: TrainingCallback = None):


 def main():
+    os.environ["TOKENIZERS_PARALLELISM"] = "true"
     parser = build_parser()
     args = parser.parse_args()
     config = args.config

View File

@@ -590,14 +590,10 @@ class APIHandler(BaseHTTPRequestHandler):
         # Determine response type
         self.request_id = f"chatcmpl-{uuid.uuid4()}"
         self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
-        if (
-            hasattr(self.tokenizer, "apply_chat_template")
-            and self.tokenizer.chat_template
-        ):
+        if self.tokenizer.chat_template:
             prompt = self.tokenizer.apply_chat_template(
                 body["messages"],
                 body.get("tools", None),
-                tokenize=True,
                 add_generation_prompt=True,
             )
         else:

View File

@@ -10,41 +10,47 @@ class Dataset:
     Light-weight wrapper to hold a dataset.
     """

-    def __init__(self, data: List[Dict[str, str]], text_key: str = "text"):
-        self._text_key = text_key
-        self._data = data
+    def __init__(
+        self,
+        data: List[Dict[str, str]],
+        tokenizer: PreTrainedTokenizer,
+        text_key: str = "text",
+    ):
+        self._data = [tokenizer.encode(d[text_key]) for d in data]
+        for d in self._data:
+            if d[-1] != tokenizer.eos_token_id:
+                d.append(tokenizer.eos_token_id)

     def __getitem__(self, idx: int):
-        return self._data[idx][self._text_key]
+        return self._data[idx]

     def __len__(self):
-        if self._data is None:
-            return 0
         return len(self._data)


-class ChatDataset(Dataset):
+class ChatDataset:
     """
     A dataset for chat data in the format of {"messages": [...]}
     https://platform.openai.com/docs/guides/fine-tuning/example-format
     """

     def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
-        super().__init__(data)
-        self._tokenizer = tokenizer
+        self._data = [
+            tokenizer.apply_chat_template(
+                d["messages"],
+                tools=d.get("tools", None),
+            )
+            for d in data
+        ]

     def __getitem__(self, idx: int):
-        messages = self._data[idx]["messages"]
-        text = self._tokenizer.apply_chat_template(
-            messages,
-            tools=self._data[idx].get("tools", None),
-            tokenize=False,
-            add_generation_prompt=True,
-        )
-        return text
+        return self._data[idx]
+
+    def __len__(self):
+        return len(self._data)


-class CompletionsDataset(Dataset):
+class CompletionsDataset:
     """
     A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
     or using user-provided keys for prompt and completion values
@@ -58,25 +64,24 @@ class CompletionsDataset(Dataset):
         prompt_key: str = "prompt",
         completion_key: str = "completion",
     ):
-        super().__init__(data)
-        self._tokenizer = tokenizer
-        self._prompt_key = prompt_key
-        self._completion_key = completion_key
+        self._data = [
+            tokenizer.apply_chat_template(
+                [
+                    {"role": "user", "content": d[prompt_key]},
+                    {"role": "assistant", "content": d[completion_key]},
+                ],
+            )
+            for d in data
+        ]

     def __getitem__(self, idx: int):
-        data = self._data[idx]
-        text = self._tokenizer.apply_chat_template(
-            [
-                {"role": "user", "content": data[self._prompt_key]},
-                {"role": "assistant", "content": data[self._completion_key]},
-            ],
-            tokenize=False,
-            add_generation_prompt=True,
-        )
-        return text
+        return self._data[idx]
+
+    def __len__(self):
+        return len(self._data)


-def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
+def create_dataset(data, tokenizer: PreTrainedTokenizer):
     sample = data[0]
     if "messages" in sample:
@@ -84,7 +89,7 @@ def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
     elif "prompt" in sample and "completion" in sample:
         return CompletionsDataset(data, tokenizer)
     elif "text" in sample:
-        return Dataset(data)
+        return Dataset(data, tokenizer)
     else:
         raise ValueError(
             "Unsupported data format, check the supported formats here:\n"
@@ -143,7 +148,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
         if prompt_feature and completion_feature:
             return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
         elif text_feature:
-            return Dataset(train_ds, text_key=text_feature)
+            return Dataset(train_ds, tokenizer, text_key=text_feature)
         else:
             raise ValueError(
                 "Specify either a prompt and completion feature or a text "
@@ -166,7 +171,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):


 def load_dataset(args, tokenizer: PreTrainedTokenizer):
-    if getattr(args, "hf_dataset", None) is not None:
+    if getattr(args, "hf_dataset", False):
         train, valid, test = load_custom_hf_dataset(args, tokenizer)
     else:
         data_path = Path(args.data)
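
Note: all three dataset wrappers now tokenize once at construction time and return token ids from __getitem__ (with Dataset guaranteeing a trailing eos_token_id), so the training loop no longer needs the tokenizer per batch. A small sketch of the new contract (assumes a Hugging Face tokenizer is already loaded as tokenizer):

text_data = [{"text": "This is an example for the model."}]
train = Dataset(text_data, tokenizer, text_key="text")
print(train[0])  # a list of token ids ending with tokenizer.eos_token_id

chat_data = [{"messages": [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]}]
chat_train = ChatDataset(chat_data, tokenizer)
print(chat_train[0])  # token ids produced by the chat template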

View File

@@ -100,14 +100,8 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
     while True:
         indices = np.random.permutation(len(batch_idx))
         for i in indices:
-            # Encode batch
-            batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
-            for b in batch:
-                if b[-1] != tokenizer.eos_token_id:
-                    b.append(tokenizer.eos_token_id)
-
+            batch = [dataset[j] for j in batch_idx[i]]
             lengths = [len(x) for x in batch]

             if max(lengths) > max_seq_length:
                 print(
                     f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "

View File

@@ -353,9 +353,13 @@ def stream_generate(
         tokenizer = TokenizerWrapper(tokenizer)

     if not isinstance(prompt, mx.array):
-        prompt = mx.array(
-            prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
-        )
+        if isinstance(prompt, str):
+            # Try to infer if special tokens are needed
+            add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
+                tokenizer.bos_token
+            )
+            prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
+        prompt = mx.array(prompt)

     detokenizer = tokenizer.detokenizer
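
Note: string prompts now get special tokens only when the text does not already start with the tokenizer's BOS token, which avoids a doubled BOS when the prompt was produced by a chat template. A compact restatement of the heuristic (sketch only):

def encode_prompt(tokenizer, prompt: str) -> list:
    # Skip special tokens when the text already begins with BOS,
    # e.g. because a chat template inserted it.
    add_special = tokenizer.bos_token is None or not prompt.startswith(tokenizer.bos_token)
    return tokenizer.encode(prompt, add_special_tokens=add_special)
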
@@ -401,7 +405,7 @@ def stream_generate(
 def generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
-    prompt: str,
+    prompt: Union[str, List[int]],
     verbose: bool = False,
     formatter: Optional[Callable] = None,
     **kwargs,
@@ -412,7 +416,7 @@ def generate(
     Args:
         model (nn.Module): The language model.
         tokenizer (PreTrainedTokenizer): The tokenizer.
-        prompt (str): The string prompt.
+        prompt (Union[str, List[int]]): The input prompt string or integer tokens.
         verbose (bool): If ``True``, print tokens and timing information.
            Default: ``False``.
         kwargs: The remaining options get passed to :func:`stream_generate`.
@@ -425,7 +429,6 @@ def generate(
         )
     if verbose:
         print("=" * 10)
-        print("Prompt:", prompt)
     text = ""
     for response in stream_generate(model, tokenizer, prompt, **kwargs):
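
Note: generate, like stream_generate, now accepts either a string or pre-computed token ids, so callers can tokenize once and reuse the ids. A short sketch (model and tokenizer as returned by mlx_lm.load):

ids = tokenizer.encode("Write a story about Einstein")
text_from_ids = generate(model, tokenizer, prompt=ids)
text_from_str = generate(model, tokenizer, prompt="Write a story about Einstein")
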
@@ -654,10 +657,10 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
 prompt="hello"

-if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
+if tokenizer.chat_template is not None:
     messages = [{{"role": "user", "content": prompt}}]
     prompt = tokenizer.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
+        messages, add_generation_prompt=True
     )

 response = generate(model, tokenizer, prompt=prompt, verbose=True)

View File

@@ -36,7 +36,8 @@ class TestDatasets(unittest.TestCase):
         data = {"text": "This is an example for the model."}
         self.save_data(4 * [data])
         args = types.SimpleNamespace(train=True, test=False, data=self.test_dir)
-        train, valid, test = datasets.load_dataset(args, None)
+        tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
+        train, valid, test = datasets.load_dataset(args, tokenizer)
         self.assertEqual(len(train), 4)
         self.assertEqual(len(valid), 4)
         self.assertEqual(len(test), 0)
@@ -82,6 +83,8 @@ class TestDatasets(unittest.TestCase):
                 "name": "billsum",
                 "prompt_feature": "text",
                 "completion_feature": "summary",
+                "train_split": "train[:2%]",
+                "valid_split": "train[-2%:]",
             },
             test=False,
             train=True,