fix encoding with special tokens + chat template (#1189)

Awni Hannun 2025-01-03 10:50:59 -08:00 committed by GitHub
parent 3a58c36109
commit c4833a2f55
13 changed files with 95 additions and 97 deletions

View File

@@ -58,7 +58,7 @@ prompt = "Write a story about Einstein"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
messages, add_generation_prompt=True
)
text = generate(model, tokenizer, prompt=prompt, verbose=True)
@@ -115,7 +115,7 @@ prompt = "Write a story about Einstein"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
messages, add_generation_prompt=True
)
for response in stream_generate(model, tokenizer, prompt, max_tokens=512):

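As a quick illustration of the README change above, here is a minimal sketch of the updated call, not taken verbatim from the repo: load and the model path are assumptions added for completeness, and with tokenize left at its default the template call returns token ids that generate and stream_generate accept directly.

from mlx_lm import load, generate

# Model path is illustrative; any model with a chat template behaves the same way.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

messages = [{"role": "user", "content": "Write a story about Einstein"}]
# tokenize defaults to True, so this returns a list of token ids, not a string.
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
text = generate(model, tokenizer, prompt=prompt, verbose=True)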
View File

@@ -110,29 +110,17 @@ def main():
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
if not args.ignore_chat_template and (
hasattr(tokenizer, "apply_chat_template")
and tokenizer.chat_template is not None
):
if not args.ignore_chat_template and tokenizer.chat_template is not None:
messages = [{"role": "user", "content": args.prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
messages, add_generation_prompt=False, continue_final_message=True
)
# Treat the prompt as a prefix assuming that the suffix will be
# provided at generation time.
test_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": "<query>"}],
tokenize=False,
add_generation_prompt=True,
)
n = len(test_prompt) - test_prompt.index("<query>") - len("<query>")
prompt = prompt[:-n]
else:
prompt = args.prompt
prompt = tokenizer.encode(args.prompt)
cache = make_prompt_cache(model, args.max_kv_size)
y = mx.array(tokenizer.encode(prompt))
y = mx.array(prompt)
# Process the prompt
start = time.time()

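A hedged sketch of the prefix idea behind the change above, assuming a model and tokenizer loaded with mlx_lm.load (the model path is illustrative): continue_final_message=True is the transformers apply_chat_template option that keeps the final user turn open, and the result is already a list of token ids.

import mlx.core as mx
from mlx_lm import load

# Illustrative model path.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

messages = [{"role": "user", "content": "Shared prefix to cache"}]
# add_generation_prompt=False plus continue_final_message=True leaves the user
# turn unterminated, so a suffix provided at generation time extends the same turn.
prefix_tokens = tokenizer.apply_chat_template(
    messages, add_generation_prompt=False, continue_final_message=True
)
y = mx.array(prefix_tokens)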
View File

@@ -72,9 +72,7 @@ def main():
if query == "q":
break
messages = [{"role": "user", "content": query}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
for response in stream_generate(
model,
tokenizer,

View File

@@ -1,4 +1,8 @@
# Adapted from a PyTorch implementation by David Grangier
# Copyright © 2024 Apple Inc.
"""
Adapted from a PyTorch implementation by David Grangier
"""
import argparse
import json
@@ -6,7 +10,7 @@ import logging
import os
from importlib.metadata import version
from pathlib import Path
from typing import Optional
from typing import Optional, Union
import lm_eval
import mlx.core as mx
@@ -277,19 +281,19 @@ class MLXLM(LM):
assert "until" in keys
untils = [x["until"] for x in options]
completions = []
for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
if (
hasattr(self._tokenizer, "apply_chat_template")
and self._tokenizer.chat_template is not None
):
if self._tokenizer.chat_template is not None:
messages = [{"role": "user", "content": context}]
context = self._tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
messages, add_generation_prompt=True
)
else:
context = self._tokenizer.encode(context)
max_tokens = min(
self._max_tokens,
self._tokenizer.model_max_length - len(self._tokenizer.encode(context)),
self._tokenizer.model_max_length - len(context),
)
text = ""
for response in stream_generate(
@@ -321,6 +325,12 @@ def main():
type=int,
help="Maximum nunber of tokens to generate. Defaults to the model's max context length.",
)
parser.add_argument(
"--limit",
default=1.0,
help="Limit the number of examples per task.",
type=float,
)
parser.add_argument("--seed", type=int, default=123, help="Random seed.")
args = parser.parse_args()
@@ -338,10 +348,12 @@ def main():
model=lm,
tasks=args.tasks,
num_fewshot=args.num_shots,
limit=args.limit,
random_seed=args.seed,
numpy_random_seed=args.seed,
torch_random_seed=args.seed,
fewshot_random_seed=args.seed,
apply_chat_template=True,
)
model_name = args.model.replace("/", "_")

View File

@@ -15,9 +15,7 @@ prompt_cache = make_prompt_cache(model)
# User turn
prompt = "Hi my name is <Name>."
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
# Assistant response
response = generate(
@@ -32,9 +30,7 @@ response = generate(
# User turn
prompt = "What's my name?"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
# Assistant response
response = generate(

View File

@@ -14,7 +14,7 @@ conversation = [{"role": "user", "content": prompt}]
# Transform the prompt into the chat template
prompt = tokenizer.apply_chat_template(
conversation=conversation, tokenize=False, add_generation_prompt=True
conversation=conversation, add_generation_prompt=True
)
# Specify the maximum number of tokens

View File

@@ -190,10 +190,7 @@ def main():
prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t")
prompt = sys.stdin.read() if prompt == "-" else prompt
if not args.ignore_chat_template and (
hasattr(tokenizer, "apply_chat_template")
and tokenizer.chat_template is not None
):
if not args.ignore_chat_template and tokenizer.chat_template is not None:
if args.system_prompt is not None:
messages = [{"role": "system", "content": args.system_prompt}]
else:
@@ -214,6 +211,10 @@ def main():
)
prompt = prompt[test_prompt.index("<query>") :]
prompt = tokenizer.encode(prompt, add_special_tokens=False)
else:
prompt = tokenizer.encode(prompt)
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
response = generate(
model,

View File

@@ -2,6 +2,7 @@
import argparse
import math
import os
import re
import types
from pathlib import Path
@@ -271,6 +272,7 @@ def run(args, training_callback: TrainingCallback = None):
def main():
os.environ["TOKENIZERS_PARALLELISM"] = "true"
parser = build_parser()
args = parser.parse_args()
config = args.config

View File

@@ -590,14 +590,10 @@ class APIHandler(BaseHTTPRequestHandler):
# Determine response type
self.request_id = f"chatcmpl-{uuid.uuid4()}"
self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
if (
hasattr(self.tokenizer, "apply_chat_template")
and self.tokenizer.chat_template
):
if self.tokenizer.chat_template:
prompt = self.tokenizer.apply_chat_template(
body["messages"],
body.get("tools", None),
tokenize=True,
add_generation_prompt=True,
)
else:

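For context, a hedged client-side sketch of calling the chat endpoint whose template handling changes above; the host, port, and route assume the local server's defaults and are not part of this diff.

import json
from urllib.request import Request, urlopen

payload = {
    "messages": [{"role": "user", "content": "Write a story about Einstein"}],
    "max_tokens": 128,
}
# Assumed default address and OpenAI-style route for the local server.
req = Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urlopen(req) as response:
    print(json.load(response)["choices"][0]["message"]["content"])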
View File

@@ -10,41 +10,47 @@ class Dataset:
Light-weight wrapper to hold a dataset.
"""
def __init__(self, data: List[Dict[str, str]], text_key: str = "text"):
self._text_key = text_key
self._data = data
def __init__(
self,
data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
text_key: str = "text",
):
self._data = [tokenizer.encode(d[text_key]) for d in data]
for d in self._data:
if d[-1] != tokenizer.eos_token_id:
d.append(tokenizer.eos_token_id)
def __getitem__(self, idx: int):
return self._data[idx][self._text_key]
return self._data[idx]
def __len__(self):
if self._data is None:
return 0
return len(self._data)
class ChatDataset(Dataset):
class ChatDataset:
"""
A dataset for chat data in the format of {"messages": [...]}
https://platform.openai.com/docs/guides/fine-tuning/example-format
"""
def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
super().__init__(data)
self._tokenizer = tokenizer
self._data = [
tokenizer.apply_chat_template(
d["messages"],
tools=d.get("tools", None),
)
for d in data
]
def __getitem__(self, idx: int):
messages = self._data[idx]["messages"]
text = self._tokenizer.apply_chat_template(
messages,
tools=self._data[idx].get("tools", None),
tokenize=False,
add_generation_prompt=True,
)
return text
return self._data[idx]
def __len__(self):
return len(self._data)
class CompletionsDataset(Dataset):
class CompletionsDataset:
"""
A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
or using user-provided keys for prompt and completion values
@@ -58,25 +64,24 @@ class CompletionsDataset(Dataset):
prompt_key: str = "prompt",
completion_key: str = "completion",
):
super().__init__(data)
self._tokenizer = tokenizer
self._prompt_key = prompt_key
self._completion_key = completion_key
self._data = [
tokenizer.apply_chat_template(
[
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[completion_key]},
],
)
for d in data
]
def __getitem__(self, idx: int):
data = self._data[idx]
text = self._tokenizer.apply_chat_template(
[
{"role": "user", "content": data[self._prompt_key]},
{"role": "assistant", "content": data[self._completion_key]},
],
tokenize=False,
add_generation_prompt=True,
)
return text
return self._data[idx]
def __len__(self):
return len(self._data)
def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
def create_dataset(data, tokenizer: PreTrainedTokenizer):
sample = data[0]
if "messages" in sample:
@@ -84,7 +89,7 @@ def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
elif "prompt" in sample and "completion" in sample:
return CompletionsDataset(data, tokenizer)
elif "text" in sample:
return Dataset(data)
return Dataset(data, tokenizer)
else:
raise ValueError(
"Unsupported data format, check the supported formats here:\n"
@@ -143,7 +148,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
if prompt_feature and completion_feature:
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
elif text_feature:
return Dataset(train_ds, text_key=text_feature)
return Dataset(train_ds, tokenizer, text_key=text_feature)
else:
raise ValueError(
"Specify either a prompt and completion feature or a text "
@@ -166,7 +171,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
def load_dataset(args, tokenizer: PreTrainedTokenizer):
if getattr(args, "hf_dataset", None) is not None:
if getattr(args, "hf_dataset", False):
train, valid, test = load_custom_hf_dataset(args, tokenizer)
else:
data_path = Path(args.data)

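A minimal sketch of how the reworked Dataset behaves after this change, with the module path and tokenizer path assumed for illustration: samples are encoded once at construction, and indexing returns token ids with EOS appended.

from transformers import AutoTokenizer

from mlx_lm.tuner.datasets import Dataset  # module path assumed

tokenizer = AutoTokenizer.from_pretrained("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
data = [{"text": "This is an example for the model."}] * 4

dataset = Dataset(data, tokenizer)           # encodes every sample up front
tokens = dataset[0]                          # a list of token ids, not raw text
assert tokens[-1] == tokenizer.eos_token_id  # EOS is appended when missing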
View File

@@ -100,14 +100,8 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
while True:
indices = np.random.permutation(len(batch_idx))
for i in indices:
# Encode batch
batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
for b in batch:
if b[-1] != tokenizer.eos_token_id:
b.append(tokenizer.eos_token_id)
batch = [dataset[j] for j in batch_idx[i]]
lengths = [len(x) for x in batch]
if max(lengths) > max_seq_length:
print(
f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "

View File

@@ -353,9 +353,13 @@ def stream_generate(
tokenizer = TokenizerWrapper(tokenizer)
if not isinstance(prompt, mx.array):
prompt = mx.array(
prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
)
if isinstance(prompt, str):
# Try to infer if special tokens are needed
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
tokenizer.bos_token
)
prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
prompt = mx.array(prompt)
detokenizer = tokenizer.detokenizer
@@ -401,7 +405,7 @@ def stream_generate(
def generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: str,
prompt: Union[str, List[int]],
verbose: bool = False,
formatter: Optional[Callable] = None,
**kwargs,
@@ -412,7 +416,7 @@ def generate(
Args:
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (str): The string prompt.
prompt (Union[str, List[int]]): The input prompt string or integer tokens.
verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``.
kwargs: The remaining options get passed to :func:`stream_generate`.
@@ -425,7 +429,6 @@
)
if verbose:
print("=" * 10)
print("Prompt:", prompt)
text = ""
for response in stream_generate(model, tokenizer, prompt, **kwargs):
@@ -654,10 +657,10 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
prompt="hello"
if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
if tokenizer.chat_template is not None:
messages = [{{"role": "user", "content": prompt}}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
messages, add_generation_prompt=True
)
response = generate(model, tokenizer, prompt=prompt, verbose=True)

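The core of the encoding fix is the BOS heuristic added to stream_generate above; a standalone restatement, purely for illustration and not part of the diff:

def encode_prompt(tokenizer, prompt: str) -> list:
    # Mirror the heuristic in stream_generate: if the string already starts with the
    # BOS token (e.g. it came from apply_chat_template with tokenize=False), skip
    # add_special_tokens so BOS is not inserted a second time.
    add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
        tokenizer.bos_token
    )
    return tokenizer.encode(prompt, add_special_tokens=add_special_tokens)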
View File

@@ -36,7 +36,8 @@ class TestDatasets(unittest.TestCase):
data = {"text": "This is an example for the model."}
self.save_data(4 * [data])
args = types.SimpleNamespace(train=True, test=False, data=self.test_dir)
train, valid, test = datasets.load_dataset(args, None)
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
train, valid, test = datasets.load_dataset(args, tokenizer)
self.assertEqual(len(train), 4)
self.assertEqual(len(valid), 4)
self.assertEqual(len(test), 0)
@@ -82,6 +83,8 @@ class TestDatasets(unittest.TestCase):
"name": "billsum",
"prompt_feature": "text",
"completion_feature": "summary",
"train_split": "train[:2%]",
"valid_split": "train[-2%:]",
},
test=False,
train=True,