LoRA: support fine-tuning on tools datasets

This commit is contained in:
madroid 2024-09-20 10:58:11 +08:00
parent f530f56df2
commit bfd4ba2347

View File

@ -41,6 +41,28 @@ class ChatDataset(Dataset):
return text
class ToolsDataset(Dataset):
    """
    A dataset of chat conversations that carry an accompanying tool
    specification, i.e. records shaped like
    {"messages": [...], "tools": [...]}.

    Each item is rendered to a single training string via the tokenizer's
    chat template, with the record's tools passed through to the template.
    See https://platform.openai.com/docs/guides/fine-tuning/fine-tuning-examples
    """

    def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
        # The parent Dataset stores `data` (exposed as self._data);
        # keep the tokenizer around for per-item template rendering.
        super().__init__(data)
        self._tokenizer = tokenizer

    def __getitem__(self, idx: int):
        # Pull the record once, then render messages + tools through
        # the chat template as untokenized text.
        record = self._data[idx]
        return self._tokenizer.apply_chat_template(
            record["messages"],
            tools=record["tools"],
            tokenize=False,
            add_generation_prompt=True,
        )
class CompletionsDataset(Dataset):
"""
A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
@ -80,7 +102,10 @@ def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
with open(path, "r") as fid:
data = [json.loads(l) for l in fid]
if "messages" in data[0]:
return ChatDataset(data, tokenizer)
if "tools" in data[0]:
return ToolsDataset(data, tokenizer)
else:
return ChatDataset(data, tokenizer)
elif "prompt" in data[0] and "completion" in data[0]:
return CompletionsDataset(data, tokenizer)
elif "text" in data[0]: