Completion only fine-tuning of instruction models with collections of HF datasets (#1103)

- Optional completion only fine-tuning with `--mask-prompt`
- Collections of Hugging Face datasets

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Chime Ogbuji
2025-02-09 23:12:34 -05:00
committed by GitHub
parent 1ced1b00ca
commit 5865899c81
6 changed files with 199 additions and 85 deletions

View File

@@ -78,14 +78,15 @@ class TestDatasets(unittest.TestCase):
self.assertTrue(isinstance(train, datasets.ChatDataset))
def test_hf(self):
hf_args = {
"name": "billsum",
"prompt_feature": "text",
"completion_feature": "summary",
"train_split": "train[:2%]",
"valid_split": "train[-2%:]",
}
args = types.SimpleNamespace(
hf_dataset={
"name": "billsum",
"prompt_feature": "text",
"completion_feature": "summary",
"train_split": "train[:2%]",
"valid_split": "train[-2%:]",
},
hf_dataset=hf_args,
test=False,
train=True,
)
@@ -97,6 +98,16 @@ class TestDatasets(unittest.TestCase):
self.assertTrue(len(valid[0]) > 0)
self.assertEqual(len(test), 0)
args = types.SimpleNamespace(
hf_dataset=[hf_args, hf_args],
test=False,
train=True,
)
train_double, valid_double, test_double = datasets.load_dataset(args, tokenizer)
self.assertEqual(2 * len(train), len(train_double))
self.assertEqual(2 * len(valid), len(valid_double))
self.assertEqual(2 * len(test), len(test_double))
if __name__ == "__main__":
unittest.main()