mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
LoRA: support tools(function calling) format datasets (#995)
* LoRA: support fine-tuning tools datasets * LoRA: Split small function * LoRA: add tools format to lora docs * LoRA: pre-commit fix * Revert "LoRA: pre-commit fix" This reverts commit b94b7e0fe7
. * Revert "LoRA: Split small function" This reverts commit 3f6a5f19fd
. * LoRA: remove ToolsDataset In a JSONL file, not all data is required to include the tools value. * nit in readme * nit in readme * nit in readme --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
ace2bb5890
commit
7ec2021bb9
@ -160,8 +160,8 @@ For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
|
|||||||
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
|
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
|
||||||
loader expects a `test.jsonl` in the data directory.
|
loader expects a `test.jsonl` in the data directory.
|
||||||
|
|
||||||
Currently, `*.jsonl` files support three data formats: `chat`,
|
Currently, `*.jsonl` files support `chat`, `tools`, `completions`, and `text`
|
||||||
`completions`, and `text`. Here are three examples of these formats:
|
data formats. Here are examples of these formats:
|
||||||
|
|
||||||
`chat`:
|
`chat`:
|
||||||
|
|
||||||
@ -169,6 +169,58 @@ Currently, `*.jsonl` files support three data formats: `chat`,
|
|||||||
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello."}, {"role": "assistant", "content": "How can I assist you today."}]}
|
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello."}, {"role": "assistant", "content": "How can I assist you today."}]}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
`tools`:
|
||||||
|
|
||||||
|
```jsonl
|
||||||
|
{"messages":[{"role":"user","content":"What is the weather in San Francisco?"},{"role":"assistant","tool_calls":[{"id":"call_id","type":"function","function":{"name":"get_current_weather","arguments":"{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"}}]}],"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and country, eg. San Francisco, USA"},"format":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location","format"]}}}]}
|
||||||
|
```
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>View the expanded single data tool format</summary>
|
||||||
|
|
||||||
|
```jsonl
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{ "role": "user", "content": "What is the weather in San Francisco?" },
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_id",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"arguments": "{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and country, eg. San Francisco, USA"
|
||||||
|
},
|
||||||
|
"format": { "type": "string", "enum": ["celsius", "fahrenheit"] }
|
||||||
|
},
|
||||||
|
"required": ["location", "format"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
`completions`:
|
`completions`:
|
||||||
|
|
||||||
```jsonl
|
```jsonl
|
||||||
@ -215,11 +267,13 @@ hf_dataset:
|
|||||||
- Arguments specified in `config` will be passed as keyword arguments to
|
- Arguments specified in `config` will be passed as keyword arguments to
|
||||||
[`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
|
[`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
|
||||||
|
|
||||||
In general, for the `chat` and `completions` formats, Hugging Face [chat
|
In general, for the `chat`, `tools` and `completions` formats, Hugging Face
|
||||||
templates](https://huggingface.co/blog/chat-templates) are used. This applies
|
[chat
|
||||||
the model's chat template by default. If the model does not have a chat
|
templates](https://huggingface.co/docs/transformers/main/en/chat_templating)
|
||||||
template, then Hugging Face will use a default. For example, the final text in
|
are used. This applies the model's chat template by default. If the model does
|
||||||
the `chat` example above with Hugging Face's default template becomes:
|
not have a chat template, then Hugging Face will use a default. For example,
|
||||||
|
the final text in the `chat` example above with Hugging Face's default template
|
||||||
|
becomes:
|
||||||
|
|
||||||
```text
|
```text
|
||||||
<|im_start|>system
|
<|im_start|>system
|
||||||
|
@ -36,7 +36,10 @@ class ChatDataset(Dataset):
|
|||||||
def __getitem__(self, idx: int):
|
def __getitem__(self, idx: int):
|
||||||
messages = self._data[idx]["messages"]
|
messages = self._data[idx]["messages"]
|
||||||
text = self._tokenizer.apply_chat_template(
|
text = self._tokenizer.apply_chat_template(
|
||||||
messages, tokenize=False, add_generation_prompt=True
|
messages,
|
||||||
|
tools=self._data[idx].get("tools", None),
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True,
|
||||||
)
|
)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
@ -93,9 +93,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
|
|||||||
# Encode batch
|
# Encode batch
|
||||||
batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
|
batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
|
||||||
for b in batch:
|
for b in batch:
|
||||||
if b[-1] == tokenizer.eos_token_id:
|
if b[-1] != tokenizer.eos_token_id:
|
||||||
print("[WARNING] Example already has an EOS token appended")
|
|
||||||
else:
|
|
||||||
b.append(tokenizer.eos_token_id)
|
b.append(tokenizer.eos_token_id)
|
||||||
|
|
||||||
lengths = [len(x) for x in batch]
|
lengths = [len(x) for x in batch]
|
||||||
|
Loading…
Reference in New Issue
Block a user