mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
LoRA: support fine-tuning tools datasets
This commit is contained in:
parent
f530f56df2
commit
bfd4ba2347
@ -41,6 +41,28 @@ class ChatDataset(Dataset):
|
||||
return text
|
||||
|
||||
|
||||
class ToolsDataset(Dataset):
    """
    A dataset for tool-calling chat data in the format of
    {"messages": [...], "tools": [...]}
    https://platform.openai.com/docs/guides/fine-tuning/fine-tuning-examples

    Each item is rendered to one untokenized string by the tokenizer's chat
    template, with the record's tool definitions forwarded to the template.
    """

    def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
        """
        Args:
            data: Parsed JSONL records, one dict per example.
            tokenizer: Tokenizer whose ``apply_chat_template`` renders the
                messages and tools into text.
        """
        super().__init__(data)
        self._tokenizer = tokenizer

    def __getitem__(self, idx: int):
        """Return the chat-template text for the record at position ``idx``."""
        record = self._data[idx]
        # The dataset type is selected by looking at the first record only,
        # so a later record may lack the "tools" key. Treat that as "no tools
        # for this example" rather than failing with KeyError.
        tools = record.get("tools")
        return self._tokenizer.apply_chat_template(
            record["messages"],
            tools=tools,
            tokenize=False,
            add_generation_prompt=True,
        )
|
||||
|
||||
|
||||
class CompletionsDataset(Dataset):
|
||||
"""
|
||||
A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
|
||||
@ -80,7 +102,10 @@ def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
|
||||
with open(path, "r") as fid:
|
||||
data = [json.loads(l) for l in fid]
|
||||
if "messages" in data[0]:
|
||||
return ChatDataset(data, tokenizer)
|
||||
if "tools" in data[0]:
|
||||
return ToolsDataset(data, tokenizer)
|
||||
else:
|
||||
return ChatDataset(data, tokenizer)
|
||||
elif "prompt" in data[0] and "completion" in data[0]:
|
||||
return CompletionsDataset(data, tokenizer)
|
||||
elif "text" in data[0]:
|
||||
|
Loading…
Reference in New Issue
Block a user