mlx-examples/llms/tests/test_datasets.py
Anupam Mediratta 5c89d1f6a6 Add instruct tuning support to LoRA training
Fixes #484

Add support for instruct tuning with input/output pairs and alternative loss functions.

* **llms/mlx_lm/lora.py**
  - Add `CompletionsDataset` class to support input/output pairs.
  - Modify `Dataset` class to handle different dataset types.
  - Update `main` function to include new dataset type.

* **llms/mlx_lm/tuner/trainer.py**
  - Modify `default_loss` function to support alternative loss functions.
  - Add new `instruct_loss` function for instruct tuning.

* **llms/mlx_lm/LORA.md**
  - Add instructions for instruct tuning with input/output pairs.
  - Update documentation to include alternative loss functions.

* **llms/tests/test_datasets.py**
  - Add tests for `CompletionsDataset` and `create_dataset` functions.

* **llms/tests/test_trainer.py**
  - Add tests for `default_loss` and `instruct_loss` functions.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/ml-explore/mlx-examples/issues/484?shareId=XXXX-XXXX-XXXX-XXXX).
2025-01-20 11:42:13 +05:30

55 lines
2.1 KiB
Python

import unittest
from transformers import PreTrainedTokenizerFast
from llms.mlx_lm.tuner.datasets import CompletionsDataset, create_dataset
class TestCompletionsDataset(unittest.TestCase):
def setUp(self):
self.tokenizer = PreTrainedTokenizerFast.from_pretrained("gpt2")
self.data = [
{"prompt": "What is the capital of France?", "completion": "Paris."},
{"prompt": "What is the capital of Germany?", "completion": "Berlin."}
]
def test_completions_dataset(self):
dataset = CompletionsDataset(self.data, self.tokenizer, "prompt", "completion")
self.assertEqual(len(dataset), 2)
self.assertTrue(isinstance(dataset[0], list))
self.assertTrue(isinstance(dataset[1], list))
class TestCreateDataset(unittest.TestCase):
def setUp(self):
self.tokenizer = PreTrainedTokenizerFast.from_pretrained("gpt2")
self.data_completions = [
{"prompt": "What is the capital of France?", "completion": "Paris."},
{"prompt": "What is the capital of Germany?", "completion": "Berlin."}
]
self.data_text = [
{"text": "This is a sample text."},
{"text": "This is another sample text."}
]
self.data_chat = [
{"messages": [{"role": "user", "content": "Hello."}, {"role": "assistant", "content": "Hi there!"}]}
]
def test_create_completions_dataset(self):
dataset = create_dataset(self.data_completions, self.tokenizer, "prompt", "completion")
self.assertEqual(len(dataset), 2)
self.assertTrue(isinstance(dataset[0], list))
self.assertTrue(isinstance(dataset[1], list))
def test_create_text_dataset(self):
dataset = create_dataset(self.data_text, self.tokenizer)
self.assertEqual(len(dataset), 2)
self.assertTrue(isinstance(dataset[0], list))
self.assertTrue(isinstance(dataset[1], list))
def test_create_chat_dataset(self):
dataset = create_dataset(self.data_chat, self.tokenizer)
self.assertEqual(len(dataset), 1)
self.assertTrue(isinstance(dataset[0], list))
if __name__ == "__main__":
unittest.main()