Mirror of https://github.com/ml-explore/mlx-examples.git

fix encoding with special tokens + chat template (#1189)

Commit: c4833a2f55
Parent: 3a58c36109
@@ -58,7 +58,7 @@ prompt = "Write a story about Einstein"
 messages = [{"role": "user", "content": prompt}]
 prompt = tokenizer.apply_chat_template(
-    messages, tokenize=False, add_generation_prompt=True
+    messages, add_generation_prompt=True
 )

 text = generate(model, tokenizer, prompt=prompt, verbose=True)
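Note: with tokenize=False dropped, apply_chat_template returns a list of token ids instead of a string, and generate (whose prompt parameter becomes Union[str, List[int]] further down in this diff) accepts either form. A minimal sketch of the difference using only a Hugging Face tokenizer; the model id is just an example and any chat model with a template would do:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
    messages = [{"role": "user", "content": "Write a story about Einstein"}]

    # Default behaviour (tokenize=True): a list of token ids, ready to pass to generate().
    token_ids = tok.apply_chat_template(messages, add_generation_prompt=True)

    # Old style: tokenize=False returns the rendered template as a string instead.
    text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    print(type(token_ids), type(text))  # <class 'list'> <class 'str'>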
@@ -115,7 +115,7 @@ prompt = "Write a story about Einstein"
 messages = [{"role": "user", "content": prompt}]
 prompt = tokenizer.apply_chat_template(
-    messages, tokenize=False, add_generation_prompt=True
+    messages, add_generation_prompt=True
 )

 for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
@@ -110,29 +110,17 @@ def main():
     if tokenizer.chat_template is None:
         tokenizer.chat_template = tokenizer.default_chat_template

-    if not args.ignore_chat_template and (
-        hasattr(tokenizer, "apply_chat_template")
-        and tokenizer.chat_template is not None
-    ):
+    if not args.ignore_chat_template and tokenizer.chat_template is not None:
         messages = [{"role": "user", "content": args.prompt}]
         prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
+            messages, add_generation_prompt=False, continue_final_message=True
         )
-
-        # Treat the prompt as a prefix assuming that the suffix will be
-        # provided at generation time.
-        test_prompt = tokenizer.apply_chat_template(
-            [{"role": "user", "content": "<query>"}],
-            tokenize=False,
-            add_generation_prompt=True,
-        )
-        n = len(test_prompt) - test_prompt.index("<query>") - len("<query>")
-        prompt = prompt[:-n]
     else:
-        prompt = args.prompt
+        prompt = tokenizer.encode(args.prompt)

     cache = make_prompt_cache(model, args.max_kv_size)
-    y = mx.array(tokenizer.encode(prompt))
+    y = mx.array(prompt)

     # Process the prompt
     start = time.time()
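Note: templating the cached prompt with add_generation_prompt=False and continue_final_message=True leaves the final user turn open, so the cached prefix ends exactly where the user text ends and the old "<query>" suffix-trimming hack is no longer needed; the non-template branch now produces token ids too, so mx.array(prompt) works for both. A sketch of the rendering difference, assuming a recent transformers release (continue_final_message is fairly new) and an example tokenizer id:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
    messages = [{"role": "user", "content": "Here is a long document to cache ..."}]

    # Old approach: render with the assistant header, then trim it off the cached prefix.
    closed = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # New approach: keep the final message open so more text can be appended later.
    open_ended = tok.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=False, continue_final_message=True
    )
    print(repr(closed[-40:]))
    print(repr(open_ended[-40:]))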
@@ -72,9 +72,7 @@ def main():
         if query == "q":
             break
         messages = [{"role": "user", "content": query}]
-        prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
+        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
         for response in stream_generate(
             model,
             tokenizer,
@@ -1,4 +1,8 @@
-# Adapted from a PyTorch implementation by David Grangier
+# Copyright © 2024 Apple Inc.
+
+"""
+Adapted from a PyTorch implementation by David Grangier
+"""

 import argparse
 import json
@@ -6,7 +10,7 @@ import logging
 import os
 from importlib.metadata import version
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Union

 import lm_eval
 import mlx.core as mx
@@ -277,19 +281,19 @@ class MLXLM(LM):
         assert "until" in keys
         untils = [x["until"] for x in options]
         completions = []

         for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
-            if (
-                hasattr(self._tokenizer, "apply_chat_template")
-                and self._tokenizer.chat_template is not None
-            ):
+            if self._tokenizer.chat_template is not None:
                 messages = [{"role": "user", "content": context}]
                 context = self._tokenizer.apply_chat_template(
-                    messages, tokenize=False, add_generation_prompt=True
+                    messages, add_generation_prompt=True
                 )
+            else:
+                context = self._tokenizer.encode(context)
+
             max_tokens = min(
                 self._max_tokens,
-                self._tokenizer.model_max_length - len(self._tokenizer.encode(context)),
+                self._tokenizer.model_max_length - len(context),
             )
             text = ""
             for response in stream_generate(
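Note: after this change both branches leave context as a list of token ids, so the remaining-context budget is computed from len(context) directly instead of re-encoding a string. A small sketch of the same two-branch logic; gpt2 is used only because it is small and has no chat template, which exercises the else branch:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    context = "What is the capital of France?"

    if tok.chat_template is not None:
        context = tok.apply_chat_template(
            [{"role": "user", "content": context}], add_generation_prompt=True
        )
    else:
        context = tok.encode(context)

    # Either way, context is now token ids, so no second encode is needed for the length.
    max_tokens = min(2048, tok.model_max_length - len(context))
    print(len(context), max_tokens)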
@@ -321,6 +325,12 @@ def main():
         type=int,
         help="Maximum nunber of tokens to generate. Defaults to the model's max context length.",
     )
+    parser.add_argument(
+        "--limit",
+        default=1.0,
+        help="Limit the number of examples per task.",
+        type=float,
+    )
     parser.add_argument("--seed", type=int, default=123, help="Random seed.")
     args = parser.parse_args()
@@ -338,10 +348,12 @@ def main():
         model=lm,
         tasks=args.tasks,
         num_fewshot=args.num_shots,
+        limit=args.limit,
         random_seed=args.seed,
         numpy_random_seed=args.seed,
         torch_random_seed=args.seed,
         fewshot_random_seed=args.seed,
+        apply_chat_template=True,
     )

     model_name = args.model.replace("/", "_")
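Note: the new --limit value is forwarded to lm_eval.simple_evaluate as limit=args.limit; as far as I recall, lm-eval treats a float below 1 as a fraction of each task's examples and an integer value as an absolute count, so the 1.0 default keeps full-task evaluation. A tiny self-contained sketch of the flag itself:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--limit",
        default=1.0,
        help="Limit the number of examples per task.",
        type=float,
    )
    args = parser.parse_args(["--limit", "0.05"])  # e.g. smoke-test on 5% of each task
    print(args.limit)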
@@ -15,9 +15,7 @@ prompt_cache = make_prompt_cache(model)
 # User turn
 prompt = "Hi my name is <Name>."
 messages = [{"role": "user", "content": prompt}]
-prompt = tokenizer.apply_chat_template(
-    messages, tokenize=False, add_generation_prompt=True
-)
+prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

 # Assistant response
 response = generate(
@@ -32,9 +30,7 @@ response = generate(
 # User turn
 prompt = "What's my name?"
 messages = [{"role": "user", "content": prompt}]
-prompt = tokenizer.apply_chat_template(
-    messages, tokenize=False, add_generation_prompt=True
-)
+prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

 # Assistant response
 response = generate(
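Note: both user turns in this example now go straight from the chat template to token ids, and the same prompt_cache object carries the conversation state between them. A hedged sketch of the full loop; the model id is only an example, and the prompt_cache keyword plus the make_prompt_cache import path are assumptions based on the surrounding README snippet rather than lines shown in this diff:

    from mlx_lm import generate, load
    from mlx_lm.models.cache import make_prompt_cache

    # Example model id (an assumption).
    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
    prompt_cache = make_prompt_cache(model)

    for question in ["Hi my name is <Name>.", "What's my name?"]:
        messages = [{"role": "user", "content": question}]
        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
        response = generate(
            model, tokenizer, prompt=prompt, verbose=True, prompt_cache=prompt_cache
        )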
@@ -14,7 +14,7 @@ conversation = [{"role": "user", "content": prompt}]

 # Transform the prompt into the chat template
 prompt = tokenizer.apply_chat_template(
-    conversation=conversation, tokenize=False, add_generation_prompt=True
+    conversation=conversation, add_generation_prompt=True
 )

 # Specify the maximum number of tokens
@@ -190,10 +190,7 @@ def main():

     prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t")
     prompt = sys.stdin.read() if prompt == "-" else prompt
-    if not args.ignore_chat_template and (
-        hasattr(tokenizer, "apply_chat_template")
-        and tokenizer.chat_template is not None
-    ):
+    if not args.ignore_chat_template and tokenizer.chat_template is not None:
         if args.system_prompt is not None:
             messages = [{"role": "system", "content": args.system_prompt}]
         else:
@@ -214,6 +211,10 @@ def main():
         )
         prompt = prompt[test_prompt.index("<query>") :]
+
+        prompt = tokenizer.encode(prompt, add_special_tokens=False)
+    else:
+        prompt = tokenizer.encode(prompt)

     sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
     response = generate(
         model,
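Note: a prompt that went through the chat template usually already starts with the BOS and other special tokens, so encoding it again with the tokenizer defaults would prepend a second BOS; that is why the templated branch encodes with add_special_tokens=False while the plain-text branch keeps the default. A small sketch of the duplication being avoided, with an example tokenizer id:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
    templated = tok.apply_chat_template(
        [{"role": "user", "content": "hi"}], tokenize=False, add_generation_prompt=True
    )

    doubled = tok.encode(templated)                           # may begin <s> <s> ...
    single = tok.encode(templated, add_special_tokens=False)  # begins <s> ... once
    print(tok.convert_ids_to_tokens(doubled[:2]))
    print(tok.convert_ids_to_tokens(single[:2]))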
@@ -2,6 +2,7 @@

 import argparse
 import math
+import os
 import re
 import types
 from pathlib import Path
@@ -271,6 +272,7 @@ def run(args, training_callback: TrainingCallback = None):


 def main():
+    os.environ["TOKENIZERS_PARALLELISM"] = "true"
     parser = build_parser()
     args = parser.parse_args()
     config = args.config
@@ -590,14 +590,10 @@ class APIHandler(BaseHTTPRequestHandler):
         # Determine response type
         self.request_id = f"chatcmpl-{uuid.uuid4()}"
         self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
-        if (
-            hasattr(self.tokenizer, "apply_chat_template")
-            and self.tokenizer.chat_template
-        ):
+        if self.tokenizer.chat_template:
             prompt = self.tokenizer.apply_chat_template(
                 body["messages"],
                 body.get("tools", None),
-                tokenize=True,
                 add_generation_prompt=True,
             )
         else:
@@ -10,41 +10,47 @@ class Dataset:
     Light-weight wrapper to hold a dataset.
     """

-    def __init__(self, data: List[Dict[str, str]], text_key: str = "text"):
-        self._text_key = text_key
-        self._data = data
+    def __init__(
+        self,
+        data: List[Dict[str, str]],
+        tokenizer: PreTrainedTokenizer,
+        text_key: str = "text",
+    ):
+        self._data = [tokenizer.encode(d[text_key]) for d in data]
+        for d in self._data:
+            if d[-1] != tokenizer.eos_token_id:
+                d.append(tokenizer.eos_token_id)

     def __getitem__(self, idx: int):
-        return self._data[idx][self._text_key]
+        return self._data[idx]

     def __len__(self):
-        if self._data is None:
-            return 0
         return len(self._data)


-class ChatDataset(Dataset):
+class ChatDataset:
     """
     A dataset for chat data in the format of {"messages": [...]}
     https://platform.openai.com/docs/guides/fine-tuning/example-format
     """

     def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
-        super().__init__(data)
-        self._tokenizer = tokenizer
+        self._data = [
+            tokenizer.apply_chat_template(
+                d["messages"],
+                tools=d.get("tools", None),
+            )
+            for d in data
+        ]

     def __getitem__(self, idx: int):
-        messages = self._data[idx]["messages"]
-        text = self._tokenizer.apply_chat_template(
-            messages,
-            tools=self._data[idx].get("tools", None),
-            tokenize=False,
-            add_generation_prompt=True,
-        )
-        return text
+        return self._data[idx]
+
+    def __len__(self):
+        return len(self._data)


-class CompletionsDataset(Dataset):
+class CompletionsDataset:
     """
     A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
     or using user-provided keys for prompt and completion values
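Note: Dataset and ChatDataset now tokenize once in __init__, and __getitem__ hands back ready-to-use token id lists (plain text gets an EOS appended); ChatDataset renders the whole conversation, assistant reply included and without a generation prompt, since these are training targets rather than prompts. A sketch of what one ChatDataset item now holds, with an example tokenizer id:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
    record = {
        "messages": [
            {"role": "user", "content": "What is 2 + 2?"},
            {"role": "assistant", "content": "4"},
        ]
    }

    # Roughly what the new ChatDataset.__init__ stores per record.
    item = tok.apply_chat_template(record["messages"], tools=record.get("tools", None))
    print(len(item), "token ids:", item[:8], "...")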
@@ -58,25 +64,24 @@ class CompletionsDataset(Dataset):
         prompt_key: str = "prompt",
         completion_key: str = "completion",
     ):
-        super().__init__(data)
-        self._tokenizer = tokenizer
-        self._prompt_key = prompt_key
-        self._completion_key = completion_key
+        self._data = [
+            tokenizer.apply_chat_template(
+                [
+                    {"role": "user", "content": d[prompt_key]},
+                    {"role": "assistant", "content": d[completion_key]},
+                ],
+            )
+            for d in data
+        ]

     def __getitem__(self, idx: int):
-        data = self._data[idx]
-        text = self._tokenizer.apply_chat_template(
-            [
-                {"role": "user", "content": data[self._prompt_key]},
-                {"role": "assistant", "content": data[self._completion_key]},
-            ],
-            tokenize=False,
-            add_generation_prompt=True,
-        )
-        return text
+        return self._data[idx]
+
+    def __len__(self):
+        return len(self._data)


-def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
+def create_dataset(data, tokenizer: PreTrainedTokenizer):
     sample = data[0]

     if "messages" in sample:
@@ -84,7 +89,7 @@ def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
     elif "prompt" in sample and "completion" in sample:
         return CompletionsDataset(data, tokenizer)
     elif "text" in sample:
-        return Dataset(data)
+        return Dataset(data, tokenizer)
     else:
         raise ValueError(
             "Unsupported data format, check the supported formats here:\n"
@@ -143,7 +148,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
     if prompt_feature and completion_feature:
         return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
     elif text_feature:
-        return Dataset(train_ds, text_key=text_feature)
+        return Dataset(train_ds, tokenizer, text_key=text_feature)
     else:
         raise ValueError(
             "Specify either a prompt and completion feature or a text "
@@ -166,7 +171,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):


 def load_dataset(args, tokenizer: PreTrainedTokenizer):
-    if getattr(args, "hf_dataset", None) is not None:
+    if getattr(args, "hf_dataset", False):
         train, valid, test = load_custom_hf_dataset(args, tokenizer)
     else:
         data_path = Path(args.data)
@@ -100,14 +100,8 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
     while True:
         indices = np.random.permutation(len(batch_idx))
         for i in indices:
-            # Encode batch
-            batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
-            for b in batch:
-                if b[-1] != tokenizer.eos_token_id:
-                    b.append(tokenizer.eos_token_id)
-
+            batch = [dataset[j] for j in batch_idx[i]]
             lengths = [len(x) for x in batch]

             if max(lengths) > max_seq_length:
                 print(
                     f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
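Note: because the datasets above now yield token ids, the per-batch encode and EOS fix-up disappear from iterate_batches; it just gathers the pre-tokenized sequences and checks their lengths. A toy, self-contained sketch of that gathering step:

    import numpy as np

    # Stand-ins: already-tokenized sequences and pre-built index buckets.
    dataset = [[1, 5, 9, 2], [1, 7, 2], [1, 3, 3, 3, 2], [1, 2]]
    batch_idx = [[0, 1], [2, 3]]
    max_seq_length = 8

    for i in np.random.permutation(len(batch_idx)):
        batch = [dataset[j] for j in batch_idx[i]]  # no tokenizer.encode() any more
        lengths = [len(x) for x in batch]
        if max(lengths) > max_seq_length:
            print(f"[WARNING] Some sequences are longer than {max_seq_length} tokens.")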
@@ -353,9 +353,13 @@ def stream_generate(
         tokenizer = TokenizerWrapper(tokenizer)

     if not isinstance(prompt, mx.array):
-        prompt = mx.array(
-            prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
-        )
+        if isinstance(prompt, str):
+            # Try to infer if special tokens are needed
+            add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
+                tokenizer.bos_token
+            )
+            prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
+        prompt = mx.array(prompt)

     detokenizer = tokenizer.detokenizer
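Note: the heuristic here is: if the tokenizer has no BOS token, or the incoming string does not already start with it, encode with special tokens; otherwise skip them, because a templated prompt typically carries its own BOS and would otherwise get a second one. A small demonstration, where the bracketed instruction string is only meant to look like a Mistral-style templated prompt (an example, not output of this code base):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

    for prompt in ["Hello there", "<s>[INST] Hello there [/INST]"]:
        add_special_tokens = tok.bos_token is None or not prompt.startswith(tok.bos_token)
        ids = tok.encode(prompt, add_special_tokens=add_special_tokens)
        print(add_special_tokens, tok.convert_ids_to_tokens(ids[:2]))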
@@ -401,7 +405,7 @@ def stream_generate(
 def generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
-    prompt: str,
+    prompt: Union[str, List[int]],
     verbose: bool = False,
     formatter: Optional[Callable] = None,
     **kwargs,
@@ -412,7 +416,7 @@ def generate(
     Args:
         model (nn.Module): The language model.
         tokenizer (PreTrainedTokenizer): The tokenizer.
-        prompt (str): The string prompt.
+        prompt (Union[str, List[int]]): The input prompt string or integer tokens.
         verbose (bool): If ``True``, print tokens and timing information.
           Default: ``False``.
         kwargs: The remaining options get passed to :func:`stream_generate`.
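Note: with the widened signature, callers can keep passing a plain string (encoded internally with the BOS heuristic above) or pass token ids from apply_chat_template directly. A hedged usage sketch; the model id is an example and max_tokens is simply forwarded to stream_generate:

    from mlx_lm import generate, load

    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

    # 1) Raw string prompt.
    print(generate(model, tokenizer, prompt="Write a haiku about the sea", max_tokens=64))

    # 2) Pre-encoded token ids, e.g. from the chat template.
    ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": "Write a haiku about the sea"}],
        add_generation_prompt=True,
    )
    print(generate(model, tokenizer, prompt=ids, max_tokens=64))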
@@ -425,7 +429,6 @@ def generate(
     )
     if verbose:
         print("=" * 10)
-        print("Prompt:", prompt)

     text = ""
     for response in stream_generate(model, tokenizer, prompt, **kwargs):
@@ -654,10 +657,10 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):

 prompt="hello"

-if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
+if tokenizer.chat_template is not None:
     messages = [{{"role": "user", "content": prompt}}]
     prompt = tokenizer.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
+        messages, add_generation_prompt=True
     )

 response = generate(model, tokenizer, prompt=prompt, verbose=True)
@@ -36,7 +36,8 @@ class TestDatasets(unittest.TestCase):
         data = {"text": "This is an example for the model."}
         self.save_data(4 * [data])
         args = types.SimpleNamespace(train=True, test=False, data=self.test_dir)
-        train, valid, test = datasets.load_dataset(args, None)
+        tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
+        train, valid, test = datasets.load_dataset(args, tokenizer)
         self.assertEqual(len(train), 4)
         self.assertEqual(len(valid), 4)
         self.assertEqual(len(test), 0)
@@ -82,6 +83,8 @@ class TestDatasets(unittest.TestCase):
                 "name": "billsum",
+                "prompt_feature": "text",
+                "completion_feature": "summary",
                 "train_split": "train[:2%]",
                 "valid_split": "train[-2%:]",
             },
             test=False,
             train=True,