diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 3d99894c..2abea970 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -41,6 +41,28 @@ class ChatDataset(Dataset):
         return text
 
 
+class ToolsDataset(Dataset):
+    """
+    A dataset for tools data in the format of {"messages": [...], "tools": [...]}
+    https://platform.openai.com/docs/guides/fine-tuning/fine-tuning-examples
+    """
+
+    def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
+        super().__init__(data)
+        self._tokenizer = tokenizer
+
+    def __getitem__(self, idx: int):
+        messages = self._data[idx]["messages"]
+        tools = self._data[idx]["tools"]
+        text = self._tokenizer.apply_chat_template(
+            messages,
+            tools=tools,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        return text
+
+
 class CompletionsDataset(Dataset):
     """
     A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
@@ -80,7 +102,10 @@ def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
     with open(path, "r") as fid:
         data = [json.loads(l) for l in fid]
     if "messages" in data[0]:
-        return ChatDataset(data, tokenizer)
+        if "tools" in data[0]:
+            return ToolsDataset(data, tokenizer)
+        else:
+            return ChatDataset(data, tokenizer)
     elif "prompt" in data[0] and "completion" in data[0]:
         return CompletionsDataset(data, tokenizer)
     elif "text" in data[0]:
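
For context, here is a minimal sketch of a training record that the new branch in create_dataset would route to ToolsDataset (both the "messages" and "tools" keys are present). It follows the OpenAI function-calling schema referenced in the docstring; the get_current_weather tool and its parameters are hypothetical, and the record is pretty-printed for readability even though a JSONL file keeps each record on a single line:

    {
      "messages": [
        {"role": "user", "content": "What is the weather in Paris?"},
        {"role": "assistant", "tool_calls": [
          {"type": "function", "function": {"name": "get_current_weather",
                                            "arguments": "{\"location\": \"Paris\"}"}}
        ]}
      ],
      "tools": [
        {"type": "function", "function": {
          "name": "get_current_weather",
          "description": "Get the current weather for a location",
          "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"]
          }
        }}
      ]
    }

ToolsDataset forwards the tools list straight to tokenizer.apply_chat_template, so the chat template of the model being tuned needs to accept the tools argument (supported by recent versions of the transformers tokenizers).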