mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-23 22:18:06 +08:00
Completion only fine-tuning of instruction models with collections of HF datasets (#1103)
- Optional completion only fine-tuning with `--mask-prompt` - Collections of Hugging Face datasets --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -78,14 +78,15 @@ class TestDatasets(unittest.TestCase):
|
||||
self.assertTrue(isinstance(train, datasets.ChatDataset))
|
||||
|
||||
def test_hf(self):
|
||||
hf_args = {
|
||||
"name": "billsum",
|
||||
"prompt_feature": "text",
|
||||
"completion_feature": "summary",
|
||||
"train_split": "train[:2%]",
|
||||
"valid_split": "train[-2%:]",
|
||||
}
|
||||
args = types.SimpleNamespace(
|
||||
hf_dataset={
|
||||
"name": "billsum",
|
||||
"prompt_feature": "text",
|
||||
"completion_feature": "summary",
|
||||
"train_split": "train[:2%]",
|
||||
"valid_split": "train[-2%:]",
|
||||
},
|
||||
hf_dataset=hf_args,
|
||||
test=False,
|
||||
train=True,
|
||||
)
|
||||
@@ -97,6 +98,16 @@ class TestDatasets(unittest.TestCase):
|
||||
self.assertTrue(len(valid[0]) > 0)
|
||||
self.assertEqual(len(test), 0)
|
||||
|
||||
args = types.SimpleNamespace(
|
||||
hf_dataset=[hf_args, hf_args],
|
||||
test=False,
|
||||
train=True,
|
||||
)
|
||||
train_double, valid_double, test_double = datasets.load_dataset(args, tokenizer)
|
||||
self.assertEqual(2 * len(train), len(train_double))
|
||||
self.assertEqual(2 * len(valid), len(valid_double))
|
||||
self.assertEqual(2 * len(test), len(test_double))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user