mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Completion only fine-tuning of instruction models with collections of HF datasets (#1103)
- Optional completion only fine-tuning with `--mask-prompt` - Collections of Hugging Face datasets --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
1ced1b00ca
commit
5865899c81
@ -76,6 +76,14 @@ You can specify the output location with `--adapter-path`.
|
|||||||
You can resume fine-tuning with an existing adapter with
|
You can resume fine-tuning with an existing adapter with
|
||||||
`--resume-adapter-file <path_to_adapters.safetensors>`.
|
`--resume-adapter-file <path_to_adapters.safetensors>`.
|
||||||
|
|
||||||
|
#### Prompt Masking
|
||||||
|
|
||||||
|
The default training computes a loss for every token in the sample. You can
|
||||||
|
ignore the prompt and compute loss for just the completion by passing
|
||||||
|
`--mask-prompt`. Note this is only supported for `chat` and `completion`
|
||||||
|
datasets. For `chat` datasets the final message in the message list is
|
||||||
|
considered the completion. See the [dataset section](#Data) for more details.
|
||||||
|
|
||||||
### Evaluate
|
### Evaluate
|
||||||
|
|
||||||
To compute test set perplexity use:
|
To compute test set perplexity use:
|
||||||
@ -290,11 +298,27 @@ hf_dataset:
|
|||||||
|
|
||||||
- Use `prompt_feature` and `completion_feature` to specify keys for a
|
- Use `prompt_feature` and `completion_feature` to specify keys for a
|
||||||
`completions` dataset. Use `text_feature` to specify the key for a `text`
|
`completions` dataset. Use `text_feature` to specify the key for a `text`
|
||||||
dataset.
|
dataset. Use `chat_feature` to specify the key for a chat dataset.
|
||||||
|
|
||||||
- To specify the train, valid, or test splits, set the corresponding
|
- To specify the train, valid, or test splits, set the corresponding
|
||||||
`{train,valid,test}_split` argument.
|
`{train,valid,test}_split` argument.
|
||||||
|
|
||||||
|
You can specify a list of Hugging Face datasets with a list of records each
|
||||||
|
with the same structure as above. For example:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
hf_dataset:
|
||||||
|
- name: "Open-Orca/OpenOrca"
|
||||||
|
train_split: "train[:90%]"
|
||||||
|
valid_split: "train[-10%:]"
|
||||||
|
prompt_feature: "question"
|
||||||
|
completion_feature: "response"
|
||||||
|
- name: "trl-lib/ultrafeedback_binarized"
|
||||||
|
train_split: "train[:90%]"
|
||||||
|
valid_split: "train[-10%:]"
|
||||||
|
chat_feature: "chosen"
|
||||||
|
```
|
||||||
|
|
||||||
- Arguments specified in `config` will be passed as keyword arguments to
|
- Arguments specified in `config` will be passed as keyword arguments to
|
||||||
[`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
|
[`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
|
||||||
|
|
||||||
|
@ -94,6 +94,14 @@ def build_parser():
|
|||||||
choices=["lora", "dora", "full"],
|
choices=["lora", "dora", "full"],
|
||||||
help="Type of fine-tuning to perform: lora, dora, or full.",
|
help="Type of fine-tuning to perform: lora, dora, or full.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--mask-prompt",
|
||||||
|
action="store_true",
|
||||||
|
help="Mask the prompt in the loss when training",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-layers",
|
"--num-layers",
|
||||||
type=int,
|
type=int,
|
||||||
@ -219,6 +227,7 @@ def train_model(
|
|||||||
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
|
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Train model
|
# Train model
|
||||||
train(
|
train(
|
||||||
model=model,
|
model=model,
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
@ -368,3 +369,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None):
|
|||||||
detokenizer_class,
|
detokenizer_class,
|
||||||
eos_token_ids=eos_token_ids,
|
eos_token_ids=eos_token_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def no_bos_or_eos(sequence: List, bos: int, eos: int) -> List:
|
||||||
|
removed_bos = sequence if sequence[0] != bos else sequence[1:]
|
||||||
|
return removed_bos[:-1] if removed_bos[-1] == eos else removed_bos
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
|
import itertools
|
||||||
import json
|
import json
|
||||||
|
import types
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
@ -34,14 +36,24 @@ class ChatDataset:
|
|||||||
https://platform.openai.com/docs/guides/fine-tuning/example-format
|
https://platform.openai.com/docs/guides/fine-tuning/example-format
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
|
def __init__(
|
||||||
self._data = [
|
self,
|
||||||
tokenizer.apply_chat_template(
|
data: List[Dict[str, str]],
|
||||||
d["messages"],
|
tokenizer: PreTrainedTokenizer,
|
||||||
tools=d.get("tools", None),
|
chat_key: str = "messages",
|
||||||
)
|
mask_prompt: bool = False,
|
||||||
for d in data
|
):
|
||||||
]
|
self._data = []
|
||||||
|
for d in data:
|
||||||
|
messages = d[chat_key]
|
||||||
|
tools = d.get("tools", None)
|
||||||
|
tokens = tokenizer.apply_chat_template(messages, tools=tools)
|
||||||
|
if mask_prompt:
|
||||||
|
messages = messages[:-1]
|
||||||
|
offset = len(tokenizer.apply_chat_template(messages, tools=tools))
|
||||||
|
self._data.append((tokens, offset))
|
||||||
|
else:
|
||||||
|
self._data.append(tokens)
|
||||||
|
|
||||||
def __getitem__(self, idx: int):
|
def __getitem__(self, idx: int):
|
||||||
return self._data[idx]
|
return self._data[idx]
|
||||||
@ -63,16 +75,36 @@ class CompletionsDataset:
|
|||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
prompt_key: str,
|
prompt_key: str,
|
||||||
completion_key: str,
|
completion_key: str,
|
||||||
|
mask_prompt: bool,
|
||||||
):
|
):
|
||||||
self._data = [
|
self._data = []
|
||||||
tokenizer.apply_chat_template(
|
for d in data:
|
||||||
|
tokens = tokenizer.apply_chat_template(
|
||||||
[
|
[
|
||||||
{"role": "user", "content": d[prompt_key]},
|
{"role": "user", "content": d[prompt_key]},
|
||||||
{"role": "assistant", "content": d[completion_key]},
|
{"role": "assistant", "content": d[completion_key]},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
for d in data
|
if mask_prompt:
|
||||||
]
|
offset = len(
|
||||||
|
tokenizer.apply_chat_template(
|
||||||
|
[{"role": "user", "content": d[prompt_key]}]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._data.append((tokens, offset))
|
||||||
|
else:
|
||||||
|
self._data.append(tokens)
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int):
|
||||||
|
return self._data[idx]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._data)
|
||||||
|
|
||||||
|
|
||||||
|
class ConcatenatedDataset:
|
||||||
|
def __init__(self, data: List[Any]):
|
||||||
|
self._data = list(itertools.chain(*data))
|
||||||
|
|
||||||
def __getitem__(self, idx: int):
|
def __getitem__(self, idx: int):
|
||||||
return self._data[idx]
|
return self._data[idx]
|
||||||
@ -84,18 +116,26 @@ class CompletionsDataset:
|
|||||||
def create_dataset(
|
def create_dataset(
|
||||||
data,
|
data,
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
prompt_feature: Optional[str] = None,
|
config,
|
||||||
completion_feature: Optional[str] = None,
|
|
||||||
):
|
):
|
||||||
prompt_feature = prompt_feature or "prompt"
|
mask_prompt = getattr(config, "mask_prompt", False)
|
||||||
completion_feature = completion_feature or "completion"
|
prompt_feature = getattr(config, "prompt_feature", "prompt")
|
||||||
|
text_feature = getattr(config, "text_feature", "text")
|
||||||
|
completion_feature = getattr(config, "completion_feature", "completion")
|
||||||
|
chat_feature = getattr(config, "chat_feature", "messages")
|
||||||
sample = data[0]
|
sample = data[0]
|
||||||
if "messages" in sample:
|
if prompt_feature in sample and completion_feature in sample:
|
||||||
return ChatDataset(data, tokenizer)
|
return CompletionsDataset(
|
||||||
elif prompt_feature in sample and completion_feature in sample:
|
data, tokenizer, prompt_feature, completion_feature, mask_prompt
|
||||||
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
|
)
|
||||||
elif "text" in sample:
|
elif chat_feature in sample:
|
||||||
return Dataset(data, tokenizer)
|
return ChatDataset(
|
||||||
|
data, tokenizer, chat_key=chat_feature, mask_prompt=mask_prompt
|
||||||
|
)
|
||||||
|
elif text_feature in sample:
|
||||||
|
if mask_prompt:
|
||||||
|
raise ValueError("Prompt masking not supported for text dataset.")
|
||||||
|
return Dataset(data, tokenizer, text_key=text_feature)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unsupported data format, check the supported formats here:\n"
|
"Unsupported data format, check the supported formats here:\n"
|
||||||
@ -106,15 +146,14 @@ def create_dataset(
|
|||||||
def load_local_dataset(
|
def load_local_dataset(
|
||||||
data_path: Path,
|
data_path: Path,
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
prompt_feature: Optional[str] = None,
|
config,
|
||||||
completion_feature: Optional[str] = None,
|
|
||||||
):
|
):
|
||||||
def load_subset(path):
|
def load_subset(path):
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
return []
|
return []
|
||||||
with open(path, "r") as fid:
|
with open(path, "r") as fid:
|
||||||
data = [json.loads(l) for l in fid]
|
data = [json.loads(l) for l in fid]
|
||||||
return create_dataset(data, tokenizer, prompt_feature, completion_feature)
|
return create_dataset(data, tokenizer, config)
|
||||||
|
|
||||||
names = ("train", "valid", "test")
|
names = ("train", "valid", "test")
|
||||||
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
|
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
|
||||||
@ -124,8 +163,7 @@ def load_local_dataset(
|
|||||||
def load_hf_dataset(
|
def load_hf_dataset(
|
||||||
data_id: str,
|
data_id: str,
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
prompt_feature: Optional[str] = None,
|
config,
|
||||||
completion_feature: Optional[str] = None,
|
|
||||||
):
|
):
|
||||||
from datasets import exceptions, load_dataset
|
from datasets import exceptions, load_dataset
|
||||||
|
|
||||||
@ -136,9 +174,7 @@ def load_hf_dataset(
|
|||||||
|
|
||||||
train, valid, test = [
|
train, valid, test = [
|
||||||
(
|
(
|
||||||
create_dataset(
|
create_dataset(dataset[n], tokenizer, config)
|
||||||
dataset[n], tokenizer, prompt_feature, completion_feature
|
|
||||||
)
|
|
||||||
if n in dataset.keys()
|
if n in dataset.keys()
|
||||||
else []
|
else []
|
||||||
)
|
)
|
||||||
@ -154,42 +190,61 @@ def load_hf_dataset(
|
|||||||
def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
|
def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
|
||||||
import datasets
|
import datasets
|
||||||
|
|
||||||
hf_args = args.hf_dataset
|
def create_hf_dataset(dataset_name, config, split, hf_config):
|
||||||
dataset_name = hf_args["name"]
|
|
||||||
print(f"Loading Hugging Face dataset {dataset_name}.")
|
|
||||||
text_feature = hf_args.get("text_feature")
|
|
||||||
prompt_feature = hf_args.get("prompt_feature")
|
|
||||||
completion_feature = hf_args.get("completion_feature")
|
|
||||||
|
|
||||||
def create_hf_dataset(split: str = None):
|
|
||||||
ds = datasets.load_dataset(
|
ds = datasets.load_dataset(
|
||||||
dataset_name,
|
dataset_name,
|
||||||
split=split,
|
split=split,
|
||||||
**hf_args.get("config", {}),
|
**hf_config,
|
||||||
)
|
|
||||||
if prompt_feature and completion_feature:
|
|
||||||
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
|
|
||||||
elif text_feature:
|
|
||||||
return Dataset(ds, tokenizer, text_key=text_feature)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Specify either a prompt and completion feature or a text "
|
|
||||||
"feature for the Hugging Face dataset."
|
|
||||||
)
|
)
|
||||||
|
return create_dataset(ds, tokenizer, config)
|
||||||
|
|
||||||
|
dataset_collection = args.hf_dataset
|
||||||
|
if isinstance(dataset_collection, dict):
|
||||||
|
dataset_collection = [dataset_collection]
|
||||||
|
|
||||||
|
collection = []
|
||||||
|
for ds in dataset_collection:
|
||||||
|
ds_name = ds["name"]
|
||||||
|
print(f"Loading Hugging Face dataset {ds_name}.")
|
||||||
|
ds["mask_prompt"] = getattr(args, "mask_prompt", False)
|
||||||
|
config = types.SimpleNamespace(**ds)
|
||||||
|
hf_config = ds.get("config", {})
|
||||||
if args.train:
|
if args.train:
|
||||||
train_split = hf_args.get("train_split", "train[:80%]")
|
train_split = ds.get("train_split", "train[:80%]")
|
||||||
valid_split = hf_args.get("valid_split", "train[-10%:]")
|
valid_split = ds.get("valid_split", "train[-10%:]")
|
||||||
train = create_hf_dataset(split=train_split)
|
train = create_hf_dataset(
|
||||||
valid = create_hf_dataset(split=valid_split)
|
ds_name,
|
||||||
|
config,
|
||||||
|
train_split,
|
||||||
|
hf_config,
|
||||||
|
)
|
||||||
|
valid = create_hf_dataset(
|
||||||
|
ds_name,
|
||||||
|
config,
|
||||||
|
valid_split,
|
||||||
|
hf_config,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
train, valid = [], []
|
train, valid = [], []
|
||||||
|
|
||||||
if args.test:
|
if args.test:
|
||||||
test = create_hf_dataset(split=hf_args.get("test_split"))
|
test_split = ds.get("test_split")
|
||||||
|
test = create_hf_dataset(
|
||||||
|
ds_name,
|
||||||
|
config,
|
||||||
|
test_split,
|
||||||
|
hf_config,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
test = []
|
test = []
|
||||||
|
|
||||||
return train, valid, test
|
collection.append((train, valid, test))
|
||||||
|
|
||||||
|
if len(collection) == 1:
|
||||||
|
return collection[0]
|
||||||
|
|
||||||
|
# Otherwise concatenate them
|
||||||
|
return tuple(map(ConcatenatedDataset, zip(*collection)))
|
||||||
|
|
||||||
|
|
||||||
def load_dataset(args, tokenizer: PreTrainedTokenizer):
|
def load_dataset(args, tokenizer: PreTrainedTokenizer):
|
||||||
@ -197,18 +252,11 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
|
|||||||
train, valid, test = load_custom_hf_dataset(args, tokenizer)
|
train, valid, test = load_custom_hf_dataset(args, tokenizer)
|
||||||
else:
|
else:
|
||||||
data_path = Path(args.data)
|
data_path = Path(args.data)
|
||||||
|
|
||||||
prompt_feature = getattr(args, "prompt_feature", None)
|
|
||||||
completion_feature = getattr(args, "completion_feature", None)
|
|
||||||
if data_path.exists():
|
if data_path.exists():
|
||||||
train, valid, test = load_local_dataset(
|
train, valid, test = load_local_dataset(data_path, tokenizer, args)
|
||||||
data_path, tokenizer, prompt_feature, completion_feature
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
print(f"Loading Hugging Face dataset {args.data}.")
|
print(f"Loading Hugging Face dataset {args.data}.")
|
||||||
train, valid, test = load_hf_dataset(
|
train, valid, test = load_hf_dataset(args.data, tokenizer, args)
|
||||||
args.data, tokenizer, prompt_feature, completion_feature
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.train and len(train) == 0:
|
if args.train and len(train) == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -5,13 +5,16 @@ import shutil
|
|||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mlx.nn.utils import average_gradients
|
from mlx.nn.utils import average_gradients
|
||||||
from mlx.utils import tree_flatten
|
from mlx.utils import tree_flatten
|
||||||
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
|
from .datasets import CompletionsDataset
|
||||||
|
|
||||||
|
|
||||||
def grad_checkpoint(layer):
|
def grad_checkpoint(layer):
|
||||||
@ -63,20 +66,30 @@ class TrainingArgs:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def default_loss(model, inputs, targets, lengths):
|
def default_loss(model, batch, lengths):
|
||||||
|
inputs = batch[:, :-1]
|
||||||
|
targets = batch[:, 1:]
|
||||||
|
|
||||||
logits = model(inputs)
|
logits = model(inputs)
|
||||||
logits = logits.astype(mx.float32)
|
logits = logits.astype(mx.float32)
|
||||||
|
|
||||||
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
|
steps = mx.arange(1, targets.shape[1] + 1)
|
||||||
|
mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
|
||||||
|
|
||||||
ce = nn.losses.cross_entropy(logits, targets) * length_mask
|
ce = nn.losses.cross_entropy(logits, targets) * mask
|
||||||
ntoks = length_mask.sum()
|
ntoks = mask.sum()
|
||||||
ce = ce.sum() / ntoks
|
ce = ce.sum() / ntoks
|
||||||
|
|
||||||
return ce, ntoks
|
return ce, ntoks
|
||||||
|
|
||||||
|
|
||||||
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
|
def iterate_batches(
|
||||||
|
dataset,
|
||||||
|
tokenizer,
|
||||||
|
batch_size,
|
||||||
|
max_seq_length,
|
||||||
|
train=False,
|
||||||
|
):
|
||||||
# Sort by length:
|
# Sort by length:
|
||||||
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
|
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
|
||||||
if len(dataset) < batch_size:
|
if len(dataset) < batch_size:
|
||||||
@ -101,6 +114,10 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
|
|||||||
indices = np.random.permutation(len(batch_idx))
|
indices = np.random.permutation(len(batch_idx))
|
||||||
for i in indices:
|
for i in indices:
|
||||||
batch = [dataset[j] for j in batch_idx[i]]
|
batch = [dataset[j] for j in batch_idx[i]]
|
||||||
|
if len(batch[0]) == 2:
|
||||||
|
batch, offsets = zip(*batch)
|
||||||
|
else:
|
||||||
|
offsets = [0] * len(batch)
|
||||||
lengths = [len(x) for x in batch]
|
lengths = [len(x) for x in batch]
|
||||||
if max(lengths) > max_seq_length:
|
if max(lengths) > max_seq_length:
|
||||||
print(
|
print(
|
||||||
@ -123,8 +140,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
|
|||||||
truncated_length # Update lengths to match truncated lengths
|
truncated_length # Update lengths to match truncated lengths
|
||||||
)
|
)
|
||||||
batch = mx.array(batch_arr)
|
batch = mx.array(batch_arr)
|
||||||
|
yield batch, mx.array(list(zip(offsets, lengths)))
|
||||||
yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
|
|
||||||
|
|
||||||
if not train:
|
if not train:
|
||||||
break
|
break
|
||||||
|
@ -78,14 +78,15 @@ class TestDatasets(unittest.TestCase):
|
|||||||
self.assertTrue(isinstance(train, datasets.ChatDataset))
|
self.assertTrue(isinstance(train, datasets.ChatDataset))
|
||||||
|
|
||||||
def test_hf(self):
|
def test_hf(self):
|
||||||
args = types.SimpleNamespace(
|
hf_args = {
|
||||||
hf_dataset={
|
|
||||||
"name": "billsum",
|
"name": "billsum",
|
||||||
"prompt_feature": "text",
|
"prompt_feature": "text",
|
||||||
"completion_feature": "summary",
|
"completion_feature": "summary",
|
||||||
"train_split": "train[:2%]",
|
"train_split": "train[:2%]",
|
||||||
"valid_split": "train[-2%:]",
|
"valid_split": "train[-2%:]",
|
||||||
},
|
}
|
||||||
|
args = types.SimpleNamespace(
|
||||||
|
hf_dataset=hf_args,
|
||||||
test=False,
|
test=False,
|
||||||
train=True,
|
train=True,
|
||||||
)
|
)
|
||||||
@ -97,6 +98,16 @@ class TestDatasets(unittest.TestCase):
|
|||||||
self.assertTrue(len(valid[0]) > 0)
|
self.assertTrue(len(valid[0]) > 0)
|
||||||
self.assertEqual(len(test), 0)
|
self.assertEqual(len(test), 0)
|
||||||
|
|
||||||
|
args = types.SimpleNamespace(
|
||||||
|
hf_dataset=[hf_args, hf_args],
|
||||||
|
test=False,
|
||||||
|
train=True,
|
||||||
|
)
|
||||||
|
train_double, valid_double, test_double = datasets.load_dataset(args, tokenizer)
|
||||||
|
self.assertEqual(2 * len(train), len(train_double))
|
||||||
|
self.assertEqual(2 * len(valid), len(valid_double))
|
||||||
|
self.assertEqual(2 * len(test), len(test_double))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user