mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Configuration-based use of HF hub-hosted datasets for training (#701)
* Add hf_dataset configuration for using HF hub-hosted datasets for (Q)LoRA training * Pre-commit formatting * Fix YAML config example * Print DS info * Include name * Add hf_dataset parameter default * Remove TextHFDataset and CompletionsHFDataset and use Dataset and CompletionsDataset instead, adding a text_key constructor argument to the former (and changing it to work with a provided data structure instead of just from a JSON file), and prompt_key and completion_key arguments to the latter with defaults for backwards compatibility. * nits * update docs --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -76,6 +76,24 @@ class TestDatasets(unittest.TestCase):
|
||||
self.assertTrue(len(valid[0]) > 0)
|
||||
self.assertTrue(isinstance(train, datasets.ChatDataset))
|
||||
|
||||
def test_hf(self):
|
||||
args = types.SimpleNamespace(
|
||||
hf_dataset={
|
||||
"name": "billsum",
|
||||
"prompt_feature": "text",
|
||||
"completion_feature": "summary",
|
||||
},
|
||||
test=False,
|
||||
train=True,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
|
||||
train, valid, test = datasets.load_dataset(args, tokenizer)
|
||||
self.assertTrue(len(train) > 0)
|
||||
self.assertTrue(len(train[0]) > 0)
|
||||
self.assertTrue(len(valid) > 0)
|
||||
self.assertTrue(len(valid[0]) > 0)
|
||||
self.assertEqual(len(test), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user