mlx-examples (mirror of https://github.com/ml-explore/mlx-examples.git)

commit c4833a2f55 (parent 3a58c36109)
fix encoding with special tokens + chat template (#1189)
@@ -58,7 +58,7 @@ prompt = "Write a story about Einstein"
 
 messages = [{"role": "user", "content": prompt}]
 prompt = tokenizer.apply_chat_template(
-    messages, tokenize=False, add_generation_prompt=True
+    messages, add_generation_prompt=True
 )
 
 text = generate(model, tokenizer, prompt=prompt, verbose=True)
@@ -115,7 +115,7 @@ prompt = "Write a story about Einstein"
 
 messages = [{"role": "user", "content": prompt}]
 prompt = tokenizer.apply_chat_template(
-    messages, tokenize=False, add_generation_prompt=True
+    messages, add_generation_prompt=True
 )
 
 for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
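Both hunks above capture the same API change: `apply_chat_template` is now left at its default `tokenize=True`, so it returns a list of token ids (including any special tokens the template inserts), and `generate`/`stream_generate` accept those tokens directly. A minimal sketch of the new flow, assuming the `mlx_lm` Python API and an illustrative model path not taken from the diff:

# Sketch of the updated prompt flow; the model path is illustrative.
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

prompt = "Write a story about Einstein"
messages = [{"role": "user", "content": prompt}]

# tokenize=True is the default, so this returns token ids rather than a string.
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

# generate() accepts either a string or a list of token ids after this change.
text = generate(model, tokenizer, prompt=prompt, verbose=True)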
@@ -110,29 +110,17 @@ def main():
     if tokenizer.chat_template is None:
         tokenizer.chat_template = tokenizer.default_chat_template
 
-    if not args.ignore_chat_template and (
-        hasattr(tokenizer, "apply_chat_template")
-        and tokenizer.chat_template is not None
-    ):
+    if not args.ignore_chat_template and tokenizer.chat_template is not None:
         messages = [{"role": "user", "content": args.prompt}]
         prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
+            messages, add_generation_prompt=False, continue_final_message=True
         )
 
-        # Treat the prompt as a prefix assuming that the suffix will be
-        # provided at generation time.
-        test_prompt = tokenizer.apply_chat_template(
-            [{"role": "user", "content": "<query>"}],
-            tokenize=False,
-            add_generation_prompt=True,
-        )
-        n = len(test_prompt) - test_prompt.index("<query>") - len("<query>")
-        prompt = prompt[:-n]
     else:
-        prompt = args.prompt
+        prompt = tokenizer.encode(args.prompt)
 
     cache = make_prompt_cache(model, args.max_kv_size)
-    y = mx.array(tokenizer.encode(prompt))
+    y = mx.array(prompt)
 
     # Process the prompt
     start = time.time()
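The prompt-caching hunk above also retires the `<query>` marker trick (render a dummy prompt, measure the suffix the template appends, and slice it off). Instead, `continue_final_message=True` asks the template itself to leave the last message open so the cached prefix can be extended at generation time. A small sketch of the difference in rendered text, assuming a transformers release that supports `continue_final_message` and an illustrative model path:

# Compare a fully closed chat prompt with one left open for continuation.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mlx-community/Llama-3.2-3B-Instruct-4bit")
messages = [{"role": "user", "content": "Write a story about Einstein"}]

closed = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
open_ended = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=False, continue_final_message=True
)

# `closed` ends with the assistant header, ready for a reply; `open_ended`
# stops inside the user turn so more text can be appended before generating.
print(repr(closed[-80:]))
print(repr(open_ended[-80:]))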
@@ -72,9 +72,7 @@ def main():
         if query == "q":
             break
         messages = [{"role": "user", "content": query}]
-        prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
+        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
         for response in stream_generate(
             model,
             tokenizer,
@@ -1,4 +1,8 @@
-# Adapted from a PyTorch implementation by David Grangier
+# Copyright © 2024 Apple Inc.
 
+"""
+Adapted from a PyTorch implementation by David Grangier
+"""
+
 import argparse
 import json
@@ -6,7 +10,7 @@ import logging
 import os
 from importlib.metadata import version
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Union
 
 import lm_eval
 import mlx.core as mx
@@ -277,19 +281,19 @@ class MLXLM(LM):
         assert "until" in keys
         untils = [x["until"] for x in options]
         completions = []
 
         for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
-            if (
-                hasattr(self._tokenizer, "apply_chat_template")
-                and self._tokenizer.chat_template is not None
-            ):
+            if self._tokenizer.chat_template is not None:
                 messages = [{"role": "user", "content": context}]
                 context = self._tokenizer.apply_chat_template(
-                    messages, tokenize=False, add_generation_prompt=True
+                    messages, add_generation_prompt=True
                 )
+            else:
+                context = self._tokenizer.encode(context)
+
             max_tokens = min(
                 self._max_tokens,
-                self._tokenizer.model_max_length - len(self._tokenizer.encode(context)),
+                self._tokenizer.model_max_length - len(context),
             )
             text = ""
             for response in stream_generate(
@@ -321,6 +325,12 @@ def main():
         type=int,
         help="Maximum nunber of tokens to generate. Defaults to the model's max context length.",
     )
+    parser.add_argument(
+        "--limit",
+        default=1.0,
+        help="Limit the number of examples per task.",
+        type=float,
+    )
     parser.add_argument("--seed", type=int, default=123, help="Random seed.")
     args = parser.parse_args()
 
@@ -338,10 +348,12 @@ def main():
         model=lm,
         tasks=args.tasks,
         num_fewshot=args.num_shots,
+        limit=args.limit,
         random_seed=args.seed,
         numpy_random_seed=args.seed,
         torch_random_seed=args.seed,
         fewshot_random_seed=args.seed,
+        apply_chat_template=True,
     )
 
     model_name = args.model.replace("/", "_")
@@ -15,9 +15,7 @@ prompt_cache = make_prompt_cache(model)
 # User turn
 prompt = "Hi my name is <Name>."
 messages = [{"role": "user", "content": prompt}]
-prompt = tokenizer.apply_chat_template(
-    messages, tokenize=False, add_generation_prompt=True
-)
+prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
 
 # Assistant response
 response = generate(
@@ -32,9 +30,7 @@ response = generate(
 # User turn
 prompt = "What's my name?"
 messages = [{"role": "user", "content": prompt}]
-prompt = tokenizer.apply_chat_template(
-    messages, tokenize=False, add_generation_prompt=True
-)
+prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
 
 # Assistant response
 response = generate(
@@ -14,7 +14,7 @@ conversation = [{"role": "user", "content": prompt}]
 
 # Transform the prompt into the chat template
 prompt = tokenizer.apply_chat_template(
-    conversation=conversation, tokenize=False, add_generation_prompt=True
+    conversation=conversation, add_generation_prompt=True
 )
 
 # Specify the maximum number of tokens
@@ -190,10 +190,7 @@ def main():
 
     prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t")
     prompt = sys.stdin.read() if prompt == "-" else prompt
-    if not args.ignore_chat_template and (
-        hasattr(tokenizer, "apply_chat_template")
-        and tokenizer.chat_template is not None
-    ):
+    if not args.ignore_chat_template and tokenizer.chat_template is not None:
         if args.system_prompt is not None:
             messages = [{"role": "system", "content": args.system_prompt}]
         else:
@@ -214,6 +211,10 @@ def main():
             )
             prompt = prompt[test_prompt.index("<query>") :]
 
+        prompt = tokenizer.encode(prompt, add_special_tokens=False)
+    else:
+        prompt = tokenizer.encode(prompt)
+
     sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
     response = generate(
         model,
@@ -2,6 +2,7 @@
 
 import argparse
 import math
+import os
 import re
 import types
 from pathlib import Path
@@ -271,6 +272,7 @@ def run(args, training_callback: TrainingCallback = None):
 
 
 def main():
+    os.environ["TOKENIZERS_PARALLELISM"] = "true"
     parser = build_parser()
     args = parser.parse_args()
     config = args.config
@@ -590,14 +590,10 @@ class APIHandler(BaseHTTPRequestHandler):
         # Determine response type
         self.request_id = f"chatcmpl-{uuid.uuid4()}"
         self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
-        if (
-            hasattr(self.tokenizer, "apply_chat_template")
-            and self.tokenizer.chat_template
-        ):
+        if self.tokenizer.chat_template:
             prompt = self.tokenizer.apply_chat_template(
                 body["messages"],
                 body.get("tools", None),
-                tokenize=True,
                 add_generation_prompt=True,
             )
         else:
@@ -10,41 +10,47 @@ class Dataset:
     Light-weight wrapper to hold a dataset.
     """
 
-    def __init__(self, data: List[Dict[str, str]], text_key: str = "text"):
-        self._text_key = text_key
-        self._data = data
+    def __init__(
+        self,
+        data: List[Dict[str, str]],
+        tokenizer: PreTrainedTokenizer,
+        text_key: str = "text",
+    ):
+        self._data = [tokenizer.encode(d[text_key]) for d in data]
+        for d in self._data:
+            if d[-1] != tokenizer.eos_token_id:
+                d.append(tokenizer.eos_token_id)
 
     def __getitem__(self, idx: int):
-        return self._data[idx][self._text_key]
+        return self._data[idx]
 
     def __len__(self):
-        if self._data is None:
-            return 0
         return len(self._data)
 
 
-class ChatDataset(Dataset):
+class ChatDataset:
     """
     A dataset for chat data in the format of {"messages": [...]}
     https://platform.openai.com/docs/guides/fine-tuning/example-format
     """
 
     def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
-        super().__init__(data)
-        self._tokenizer = tokenizer
+        self._data = [
+            tokenizer.apply_chat_template(
+                d["messages"],
+                tools=d.get("tools", None),
+            )
+            for d in data
+        ]
 
     def __getitem__(self, idx: int):
-        messages = self._data[idx]["messages"]
-        text = self._tokenizer.apply_chat_template(
-            messages,
-            tools=self._data[idx].get("tools", None),
-            tokenize=False,
-            add_generation_prompt=True,
-        )
-        return text
+        return self._data[idx]
+
+    def __len__(self):
+        return len(self._data)
 
 
-class CompletionsDataset(Dataset):
+class CompletionsDataset:
     """
     A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
     or using user-provided keys for prompt and completion values
@@ -58,25 +64,24 @@ class CompletionsDataset(Dataset):
         prompt_key: str = "prompt",
         completion_key: str = "completion",
     ):
-        super().__init__(data)
-        self._tokenizer = tokenizer
-        self._prompt_key = prompt_key
-        self._completion_key = completion_key
+        self._data = [
+            tokenizer.apply_chat_template(
+                [
+                    {"role": "user", "content": d[prompt_key]},
+                    {"role": "assistant", "content": d[completion_key]},
+                ],
+            )
+            for d in data
+        ]
 
     def __getitem__(self, idx: int):
-        data = self._data[idx]
-        text = self._tokenizer.apply_chat_template(
-            [
-                {"role": "user", "content": data[self._prompt_key]},
-                {"role": "assistant", "content": data[self._completion_key]},
-            ],
-            tokenize=False,
-            add_generation_prompt=True,
-        )
-        return text
+        return self._data[idx]
+
+    def __len__(self):
+        return len(self._data)
 
 
-def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
+def create_dataset(data, tokenizer: PreTrainedTokenizer):
     sample = data[0]
 
     if "messages" in sample:
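The two dataset hunks above move tokenization into the constructors: `Dataset` encodes each `text` entry and appends `eos_token_id` when missing, while `ChatDataset` and `CompletionsDataset` run `apply_chat_template` with its default `tokenize=True`, so `__getitem__` now yields token ids rather than strings. A short sketch of the resulting behavior, assuming the classes are importable from `mlx_lm.tuner.datasets` and using an illustrative tokenizer:

# Sketch: items come back pre-tokenized after this refactor.
from transformers import AutoTokenizer
from mlx_lm.tuner.datasets import ChatDataset, Dataset

tokenizer = AutoTokenizer.from_pretrained("mlx-community/Llama-3.2-3B-Instruct-4bit")

text_data = [{"text": "This is an example for the model."}]
chat_data = [
    {
        "messages": [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},
        ]
    }
]

train = Dataset(text_data, tokenizer)      # token ids, eos appended if missing
chats = ChatDataset(chat_data, tokenizer)  # chat template applied and tokenized

print(train[0][-1] == tokenizer.eos_token_id)  # expected True
print(type(chats[0][0]))                       # int token ids, not str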
@@ -84,7 +89,7 @@ def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
     elif "prompt" in sample and "completion" in sample:
         return CompletionsDataset(data, tokenizer)
     elif "text" in sample:
-        return Dataset(data)
+        return Dataset(data, tokenizer)
     else:
         raise ValueError(
             "Unsupported data format, check the supported formats here:\n"
@@ -143,7 +148,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
         if prompt_feature and completion_feature:
             return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
         elif text_feature:
-            return Dataset(train_ds, text_key=text_feature)
+            return Dataset(train_ds, tokenizer, text_key=text_feature)
         else:
             raise ValueError(
                 "Specify either a prompt and completion feature or a text "
@@ -166,7 +171,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
 
 
 def load_dataset(args, tokenizer: PreTrainedTokenizer):
-    if getattr(args, "hf_dataset", None) is not None:
+    if getattr(args, "hf_dataset", False):
         train, valid, test = load_custom_hf_dataset(args, tokenizer)
     else:
         data_path = Path(args.data)
@@ -100,14 +100,8 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
     while True:
         indices = np.random.permutation(len(batch_idx))
         for i in indices:
-            # Encode batch
-            batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
-            for b in batch:
-                if b[-1] != tokenizer.eos_token_id:
-                    b.append(tokenizer.eos_token_id)
-
+            batch = [dataset[j] for j in batch_idx[i]]
             lengths = [len(x) for x in batch]
-
             if max(lengths) > max_seq_length:
                 print(
                     f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
@@ -353,9 +353,13 @@ def stream_generate(
         tokenizer = TokenizerWrapper(tokenizer)
 
     if not isinstance(prompt, mx.array):
-        prompt = mx.array(
-            prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
-        )
+        if isinstance(prompt, str):
+            # Try to infer if special tokens are needed
+            add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
+                tokenizer.bos_token
+            )
+            prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
+        prompt = mx.array(prompt)
 
     detokenizer = tokenizer.detokenizer
 
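This `stream_generate` hunk is the core of the "encoding with special tokens" fix: a string prompt that already went through a chat template typically begins with the BOS token, and re-encoding it with `add_special_tokens=True` would prepend a second BOS. The new code only adds special tokens when the prompt does not already start with `tokenizer.bos_token`. A standalone sketch of the same heuristic, using an illustrative tokenizer:

# Sketch of the BOS-inference heuristic applied outside of stream_generate.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mlx-community/Llama-3.2-3B-Instruct-4bit")

def encode_prompt(prompt: str) -> list:
    # Skip automatic special tokens when the text already carries a BOS token
    # (e.g. it came from apply_chat_template with tokenize=False); otherwise
    # the encoded prompt would start with two BOS tokens.
    add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
        tokenizer.bos_token
    )
    return tokenizer.encode(prompt, add_special_tokens=add_special_tokens)

templated = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hi"}], tokenize=False, add_generation_prompt=True
)
print(encode_prompt(templated).count(tokenizer.bos_token_id))  # expected 1, not 2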
@@ -401,7 +405,7 @@
 def generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
-    prompt: str,
+    prompt: Union[str, List[int]],
     verbose: bool = False,
     formatter: Optional[Callable] = None,
     **kwargs,
@@ -412,7 +416,7 @@ def generate(
     Args:
         model (nn.Module): The language model.
         tokenizer (PreTrainedTokenizer): The tokenizer.
-        prompt (str): The string prompt.
+        prompt (Union[str, List[int]]): The input prompt string or integer tokens.
         verbose (bool): If ``True``, print tokens and timing information.
           Default: ``False``.
         kwargs: The remaining options get passed to :func:`stream_generate`.
@@ -425,7 +429,6 @@ def generate(
     )
     if verbose:
         print("=" * 10)
-        print("Prompt:", prompt)
 
     text = ""
     for response in stream_generate(model, tokenizer, prompt, **kwargs):
@@ -654,10 +657,10 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
 
 prompt="hello"
 
-if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
+if tokenizer.chat_template is not None:
     messages = [{{"role": "user", "content": prompt}}]
     prompt = tokenizer.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
+        messages, add_generation_prompt=True
     )
 
 response = generate(model, tokenizer, prompt=prompt, verbose=True)
@@ -36,7 +36,8 @@ class TestDatasets(unittest.TestCase):
         data = {"text": "This is an example for the model."}
         self.save_data(4 * [data])
         args = types.SimpleNamespace(train=True, test=False, data=self.test_dir)
-        train, valid, test = datasets.load_dataset(args, None)
+        tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
+        train, valid, test = datasets.load_dataset(args, tokenizer)
         self.assertEqual(len(train), 4)
         self.assertEqual(len(valid), 4)
         self.assertEqual(len(test), 0)
@@ -82,6 +83,8 @@ class TestDatasets(unittest.TestCase):
                 "name": "billsum",
                 "prompt_feature": "text",
                 "completion_feature": "summary",
+                "train_split": "train[:2%]",
+                "valid_split": "train[-2%:]",
             },
             test=False,
             train=True,