mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
Merge branch 'ml-explore:main' into adding-support-for-mamba2
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
@@ -10,41 +10,47 @@ class Dataset:
|
||||
Light-weight wrapper to hold a dataset.
|
||||
"""
|
||||
|
||||
def __init__(self, data: List[Dict[str, str]], text_key: str = "text"):
|
||||
self._text_key = text_key
|
||||
self._data = data
|
||||
def __init__(
|
||||
self,
|
||||
data: List[Dict[str, str]],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
text_key: str = "text",
|
||||
):
|
||||
self._data = [tokenizer.encode(d[text_key]) for d in data]
|
||||
for d in self._data:
|
||||
if d[-1] != tokenizer.eos_token_id:
|
||||
d.append(tokenizer.eos_token_id)
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
return self._data[idx][self._text_key]
|
||||
return self._data[idx]
|
||||
|
||||
def __len__(self):
|
||||
if self._data is None:
|
||||
return 0
|
||||
return len(self._data)
|
||||
|
||||
|
||||
class ChatDataset(Dataset):
|
||||
class ChatDataset:
|
||||
"""
|
||||
A dataset for chat data in the format of {"messages": [...]}
|
||||
https://platform.openai.com/docs/guides/fine-tuning/example-format
|
||||
"""
|
||||
|
||||
def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
|
||||
super().__init__(data)
|
||||
self._tokenizer = tokenizer
|
||||
self._data = [
|
||||
tokenizer.apply_chat_template(
|
||||
d["messages"],
|
||||
tools=d.get("tools", None),
|
||||
)
|
||||
for d in data
|
||||
]
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
messages = self._data[idx]["messages"]
|
||||
text = self._tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tools=self._data[idx].get("tools", None),
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
return text
|
||||
return self._data[idx]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._data)
|
||||
|
||||
|
||||
class CompletionsDataset(Dataset):
|
||||
class CompletionsDataset:
|
||||
"""
|
||||
A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
|
||||
or using user-provided keys for prompt and completion values
|
||||
@@ -55,36 +61,41 @@ class CompletionsDataset(Dataset):
|
||||
self,
|
||||
data: List[Dict[str, str]],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
prompt_key: str = "prompt",
|
||||
completion_key: str = "completion",
|
||||
prompt_key: str,
|
||||
completion_key: str,
|
||||
):
|
||||
super().__init__(data)
|
||||
self._tokenizer = tokenizer
|
||||
self._prompt_key = prompt_key
|
||||
self._completion_key = completion_key
|
||||
self._data = [
|
||||
tokenizer.apply_chat_template(
|
||||
[
|
||||
{"role": "user", "content": d[prompt_key]},
|
||||
{"role": "assistant", "content": d[completion_key]},
|
||||
],
|
||||
)
|
||||
for d in data
|
||||
]
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
data = self._data[idx]
|
||||
text = self._tokenizer.apply_chat_template(
|
||||
[
|
||||
{"role": "user", "content": data[self._prompt_key]},
|
||||
{"role": "assistant", "content": data[self._completion_key]},
|
||||
],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
return text
|
||||
return self._data[idx]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._data)
|
||||
|
||||
|
||||
def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
|
||||
def create_dataset(
|
||||
data,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
prompt_feature: Optional[str] = None,
|
||||
completion_feature: Optional[str] = None,
|
||||
):
|
||||
prompt_feature = prompt_feature or "prompt"
|
||||
completion_feature = completion_feature or "completion"
|
||||
sample = data[0]
|
||||
|
||||
if "messages" in sample:
|
||||
return ChatDataset(data, tokenizer)
|
||||
elif "prompt" in sample and "completion" in sample:
|
||||
return CompletionsDataset(data, tokenizer)
|
||||
elif prompt_feature in sample and completion_feature in sample:
|
||||
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
|
||||
elif "text" in sample:
|
||||
return Dataset(data)
|
||||
return Dataset(data, tokenizer)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported data format, check the supported formats here:\n"
|
||||
@@ -92,20 +103,30 @@ def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
|
||||
)
|
||||
|
||||
|
||||
def load_local_dataset(data_path: Path, tokenizer: PreTrainedTokenizer):
|
||||
def load_local_dataset(
|
||||
data_path: Path,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
prompt_feature: Optional[str] = None,
|
||||
completion_feature: Optional[str] = None,
|
||||
):
|
||||
def load_subset(path):
|
||||
if not path.exists():
|
||||
return []
|
||||
with open(path, "r") as fid:
|
||||
data = [json.loads(l) for l in fid]
|
||||
return create_dataset(data, tokenizer)
|
||||
return create_dataset(data, tokenizer, prompt_feature, completion_feature)
|
||||
|
||||
names = ("train", "valid", "test")
|
||||
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
|
||||
return train, valid, test
|
||||
|
||||
|
||||
def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
|
||||
def load_hf_dataset(
|
||||
data_id: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
prompt_feature: Optional[str] = None,
|
||||
completion_feature: Optional[str] = None,
|
||||
):
|
||||
from datasets import exceptions, load_dataset
|
||||
|
||||
try:
|
||||
@@ -114,7 +135,13 @@ def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
|
||||
names = ("train", "valid", "test")
|
||||
|
||||
train, valid, test = [
|
||||
create_dataset(dataset[n], tokenizer) if n in dataset.keys() else []
|
||||
(
|
||||
create_dataset(
|
||||
dataset[n], tokenizer, prompt_feature, completion_feature
|
||||
)
|
||||
if n in dataset.keys()
|
||||
else []
|
||||
)
|
||||
for n in names
|
||||
]
|
||||
|
||||
@@ -143,7 +170,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
|
||||
if prompt_feature and completion_feature:
|
||||
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
|
||||
elif text_feature:
|
||||
return Dataset(train_ds, text_key=text_feature)
|
||||
return Dataset(train_ds, tokenizer, text_key=text_feature)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Specify either a prompt and completion feature or a text "
|
||||
@@ -166,15 +193,22 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
|
||||
|
||||
|
||||
def load_dataset(args, tokenizer: PreTrainedTokenizer):
|
||||
if getattr(args, "hf_dataset", None) is not None:
|
||||
if getattr(args, "hf_dataset", False):
|
||||
train, valid, test = load_custom_hf_dataset(args, tokenizer)
|
||||
else:
|
||||
data_path = Path(args.data)
|
||||
|
||||
prompt_feature = getattr(args, "prompt_feature", None)
|
||||
completion_feature = getattr(args, "completion_feature", None)
|
||||
if data_path.exists():
|
||||
train, valid, test = load_local_dataset(data_path, tokenizer)
|
||||
train, valid, test = load_local_dataset(
|
||||
data_path, tokenizer, prompt_feature, completion_feature
|
||||
)
|
||||
else:
|
||||
print(f"Loading Hugging Face dataset {args.data}.")
|
||||
train, valid, test = load_hf_dataset(args.data, tokenizer)
|
||||
train, valid, test = load_hf_dataset(
|
||||
args.data, tokenizer, prompt_feature, completion_feature
|
||||
)
|
||||
|
||||
if args.train and len(train) == 0:
|
||||
raise ValueError(
|
||||
|
||||
@@ -100,14 +100,8 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
|
||||
while True:
|
||||
indices = np.random.permutation(len(batch_idx))
|
||||
for i in indices:
|
||||
# Encode batch
|
||||
batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
|
||||
for b in batch:
|
||||
if b[-1] != tokenizer.eos_token_id:
|
||||
b.append(tokenizer.eos_token_id)
|
||||
|
||||
batch = [dataset[j] for j in batch_idx[i]]
|
||||
lengths = [len(x) for x in batch]
|
||||
|
||||
if max(lengths) > max_seq_length:
|
||||
print(
|
||||
f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
|
||||
|
||||
Reference in New Issue
Block a user