# mlx-examples/llms/tests/test_datasets.py
#
# Unit tests for the mlx_lm tuner dataset helpers
# (CompletionsDataset and create_dataset).

import unittest
from transformers import PreTrainedTokenizerFast
from llms.mlx_lm.tuner.datasets import CompletionsDataset, create_dataset
class TestCompletionsDataset(unittest.TestCase):
    """Tests for CompletionsDataset built from prompt/completion records."""

    @classmethod
    def setUpClass(cls):
        # Loading the tokenizer hits the network/disk cache; do it once for
        # the whole class instead of once per test method.
        cls.tokenizer = PreTrainedTokenizerFast.from_pretrained("gpt2")

    def setUp(self):
        # Fresh copy of the fixture data for each test so a test that
        # mutates it cannot leak state into its siblings.
        self.tokenizer = type(self).tokenizer
        self.data = [
            {"prompt": "What is the capital of France?", "completion": "Paris."},
            {"prompt": "What is the capital of Germany?", "completion": "Berlin."},
        ]

    def test_completions_dataset(self):
        """Each record should be exposed as a tokenized list of ids."""
        dataset = CompletionsDataset(self.data, self.tokenizer, "prompt", "completion")
        self.assertEqual(len(dataset), 2)
        # NOTE(review): assumes __getitem__ returns a plain list of token
        # ids — confirm against CompletionsDataset's implementation.
        self.assertTrue(isinstance(dataset[0], list))
        self.assertTrue(isinstance(dataset[1], list))
class TestCreateDataset(unittest.TestCase):
    """Tests that create_dataset dispatches on the record schema:
    prompt/completion pairs, raw text, and chat-message lists."""

    @classmethod
    def setUpClass(cls):
        # The tokenizer download/load is expensive; share one instance
        # across all tests in this class.
        cls.tokenizer = PreTrainedTokenizerFast.from_pretrained("gpt2")

    def setUp(self):
        self.tokenizer = type(self).tokenizer
        self.data_completions = [
            {"prompt": "What is the capital of France?", "completion": "Paris."},
            {"prompt": "What is the capital of Germany?", "completion": "Berlin."},
        ]
        self.data_text = [
            {"text": "This is a sample text."},
            {"text": "This is another sample text."},
        ]
        self.data_chat = [
            {
                "messages": [
                    {"role": "user", "content": "Hello."},
                    {"role": "assistant", "content": "Hi there!"},
                ]
            }
        ]

    def test_create_completions_dataset(self):
        """prompt/completion records yield a completions-style dataset."""
        dataset = create_dataset(self.data_completions, self.tokenizer, "prompt", "completion")
        self.assertEqual(len(dataset), 2)
        self.assertTrue(isinstance(dataset[0], list))
        self.assertTrue(isinstance(dataset[1], list))

    def test_create_text_dataset(self):
        """Records with a 'text' key yield a plain-text dataset."""
        dataset = create_dataset(self.data_text, self.tokenizer)
        self.assertEqual(len(dataset), 2)
        self.assertTrue(isinstance(dataset[0], list))
        self.assertTrue(isinstance(dataset[1], list))

    def test_create_chat_dataset(self):
        """Records with a 'messages' key yield a chat dataset."""
        dataset = create_dataset(self.data_chat, self.tokenizer)
        self.assertEqual(len(dataset), 1)
        self.assertTrue(isinstance(dataset[0], list))
# Allow running this file directly: python test_datasets.py
if __name__ == "__main__":
    unittest.main()