Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-08-29 06:54:39 +08:00

Fixes #484

Add support for instruct tuning with input/output pairs and alternative loss functions.

* **llms/mlx_lm/lora.py**
  - Add `CompletionsDataset` class to support input/output pairs (see the usage sketch after this change list).
  - Modify `Dataset` class to handle different dataset types.
  - Update `main` function to include the new dataset type.
* **llms/mlx_lm/tuner/trainer.py**
  - Modify `default_loss` function to support alternative loss functions.
  - Add new `instruct_loss` function for instruct tuning.
* **llms/mlx_lm/LORA.md**
  - Add instructions for instruct tuning with input/output pairs.
  - Update documentation to include alternative loss functions.
* **llms/tests/test_datasets.py**
  - Add tests for the `CompletionsDataset` and `create_dataset` functions.
* **llms/tests/test_trainer.py**
  - Add tests for the `default_loss` and `instruct_loss` functions.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/ml-explore/mlx-examples/issues/484?shareId=XXXX-XXXX-XXXX-XXXX).
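A minimal usage sketch of the completions-style data this change introduces. The record keys and the `CompletionsDataset` constructor signature are taken from the tests in `llms/tests/test_datasets.py` below; the tokenizer choice and the assumption that indexing returns a list of token ids are illustrative, not confirmed by this diff.

```python
from transformers import PreTrainedTokenizerFast

from llms.mlx_lm.tuner.datasets import CompletionsDataset

# Each record pairs an instruction-style prompt with its target completion.
data = [
    {"prompt": "What is the capital of France?", "completion": "Paris."},
    {"prompt": "What is the capital of Germany?", "completion": "Berlin."},
]

tokenizer = PreTrainedTokenizerFast.from_pretrained("gpt2")

# The last two arguments name the prompt and completion keys in each record
# (signature assumed from the tests below).
dataset = CompletionsDataset(data, tokenizer, "prompt", "completion")

print(len(dataset))  # 2
print(dataset[0])    # a list (presumably token ids) for the first pair
```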
llms/tests/test_datasets.py · 55 lines · 2.1 KiB · Python
import unittest

from transformers import PreTrainedTokenizerFast

from llms.mlx_lm.tuner.datasets import CompletionsDataset, create_dataset


class TestCompletionsDataset(unittest.TestCase):
    def setUp(self):
        # The GPT-2 tokenizer is downloaded from the Hugging Face Hub on first use.
        self.tokenizer = PreTrainedTokenizerFast.from_pretrained("gpt2")
        self.data = [
            {"prompt": "What is the capital of France?", "completion": "Paris."},
            {"prompt": "What is the capital of Germany?", "completion": "Berlin."},
        ]

    def test_completions_dataset(self):
        dataset = CompletionsDataset(self.data, self.tokenizer, "prompt", "completion")
        # Each record should produce one entry, and indexing should yield a tokenized list.
        self.assertEqual(len(dataset), 2)
        self.assertTrue(isinstance(dataset[0], list))
        self.assertTrue(isinstance(dataset[1], list))


class TestCreateDataset(unittest.TestCase):
    def setUp(self):
        self.tokenizer = PreTrainedTokenizerFast.from_pretrained("gpt2")
        # Prompt/completion pairs for the completions-style dataset.
        self.data_completions = [
            {"prompt": "What is the capital of France?", "completion": "Paris."},
            {"prompt": "What is the capital of Germany?", "completion": "Berlin."},
        ]
        # Plain text records.
        self.data_text = [
            {"text": "This is a sample text."},
            {"text": "This is another sample text."},
        ]
        # A single chat-style record with a user/assistant exchange.
        self.data_chat = [
            {
                "messages": [
                    {"role": "user", "content": "Hello."},
                    {"role": "assistant", "content": "Hi there!"},
                ]
            }
        ]

    def test_create_completions_dataset(self):
        dataset = create_dataset(self.data_completions, self.tokenizer, "prompt", "completion")
        self.assertEqual(len(dataset), 2)
        self.assertTrue(isinstance(dataset[0], list))
        self.assertTrue(isinstance(dataset[1], list))

    def test_create_text_dataset(self):
        dataset = create_dataset(self.data_text, self.tokenizer)
        self.assertEqual(len(dataset), 2)
        self.assertTrue(isinstance(dataset[0], list))
        self.assertTrue(isinstance(dataset[1], list))

    def test_create_chat_dataset(self):
        dataset = create_dataset(self.data_chat, self.tokenizer)
        self.assertEqual(len(dataset), 1)
        self.assertTrue(isinstance(dataset[0], list))


if __name__ == "__main__":
    unittest.main()
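Since the file ends with `unittest.main()`, these tests can be run directly with `python llms/tests/test_datasets.py` or through `python -m unittest`; note that the first run fetches the `gpt2` tokenizer from the Hugging Face Hub, so it needs network access or a local cache.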