# Mirror of https://github.com/ml-explore/mlx-examples.git
# (synced 2025-08-30 10:56:38 +08:00; 55 lines, 2.1 KiB, Python)
import unittest
|
||
|
from transformers import PreTrainedTokenizerFast
|
||
|
from llms.mlx_lm.tuner.datasets import CompletionsDataset, create_dataset
|
||
|
|
||
|
class TestCompletionsDataset(unittest.TestCase):
    """Tests for ``CompletionsDataset`` built from prompt/completion records."""

    def setUp(self):
        """Load the "gpt2" tokenizer and build a two-record fixture."""
        self.tokenizer = PreTrainedTokenizerFast.from_pretrained("gpt2")
        self.data = [
            {"prompt": "What is the capital of France?", "completion": "Paris."},
            {"prompt": "What is the capital of Germany?", "completion": "Berlin."},
        ]

    def test_completions_dataset(self):
        """The dataset has one entry per record and each entry is a list."""
        dataset = CompletionsDataset(self.data, self.tokenizer, "prompt", "completion")
        self.assertEqual(len(dataset), 2)
        # assertIsInstance reports the offending object and expected type on
        # failure; assertTrue(isinstance(...)) only says "False is not true".
        self.assertIsInstance(dataset[0], list)
        self.assertIsInstance(dataset[1], list)
|
||
|
|
||
|
class TestCreateDataset(unittest.TestCase):
    """Tests for ``create_dataset`` with completions, text and chat records."""

    def setUp(self):
        """Load the "gpt2" tokenizer and build one fixture per record layout."""
        self.tokenizer = PreTrainedTokenizerFast.from_pretrained("gpt2")
        # Records with separate "prompt" / "completion" keys.
        self.data_completions = [
            {"prompt": "What is the capital of France?", "completion": "Paris."},
            {"prompt": "What is the capital of Germany?", "completion": "Berlin."},
        ]
        # Records with a single "text" key.
        self.data_text = [
            {"text": "This is a sample text."},
            {"text": "This is another sample text."},
        ]
        # A single record holding a "messages" list of role/content dicts.
        self.data_chat = [
            {
                "messages": [
                    {"role": "user", "content": "Hello."},
                    {"role": "assistant", "content": "Hi there!"},
                ]
            }
        ]

    def test_create_completions_dataset(self):
        """Prompt/completion records give two entries, each a list."""
        dataset = create_dataset(
            self.data_completions, self.tokenizer, "prompt", "completion"
        )
        self.assertEqual(len(dataset), 2)
        # assertIsInstance gives a descriptive failure message, unlike
        # assertTrue(isinstance(...)).
        self.assertIsInstance(dataset[0], list)
        self.assertIsInstance(dataset[1], list)

    def test_create_text_dataset(self):
        """Text records give two entries, each a list."""
        dataset = create_dataset(self.data_text, self.tokenizer)
        self.assertEqual(len(dataset), 2)
        self.assertIsInstance(dataset[0], list)
        self.assertIsInstance(dataset[1], list)

    def test_create_chat_dataset(self):
        """The single chat record gives one entry, which is a list."""
        dataset = create_dataset(self.data_chat, self.tokenizer)
        self.assertEqual(len(dataset), 1)
        self.assertIsInstance(dataset[0], list)
|
||
|
|
||
|
# Run the test suite only when this module is executed as a script,
# not when it is imported by a test runner.
if __name__ == "__main__":
    unittest.main()
|