mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 06:22:46 +08:00

Fixes #484

Add support for instruct tuning with input/output pairs and alternative loss functions.

* **llms/mlx_lm/lora.py**
  - Add `CompletionsDataset` class to support input/output pairs.
  - Modify `Dataset` class to handle different dataset types.
  - Update `main` function to include new dataset type.
* **llms/mlx_lm/tuner/trainer.py**
  - Modify `default_loss` function to support alternative loss functions.
  - Add new `instruct_loss` function for instruct tuning.
* **llms/mlx_lm/LORA.md**
  - Add instructions for instruct tuning with input/output pairs.
  - Update documentation to include alternative loss functions.
* **llms/tests/test_datasets.py**
  - Add tests for `CompletionsDataset` and `create_dataset` functions.
* **llms/tests/test_trainer.py**
  - Add tests for `default_loss` and `instruct_loss` functions.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/ml-explore/mlx-examples/issues/484?shareId=XXXX-XXXX-XXXX-XXXX).
40 lines
1.5 KiB
Python
40 lines
1.5 KiB
Python
import unittest

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten
from transformers import PreTrainedTokenizerFast

from llms.mlx_lm.tuner.trainer import default_loss, instruct_loss
|
|
|
|
class TestLossFunctions(unittest.TestCase):
    """Smoke tests for the ``default_loss`` and ``instruct_loss`` functions.

    Each test calls a loss function on a tiny two-sequence batch and checks
    that both returned values (the loss and the token count) are MLX arrays.
    """

    # Must be strictly greater than the largest token id used in ``setUp``.
    VOCAB_SIZE = 8

    def setUp(self):
        """Create a minimal callable model and a small batch of token ids."""
        # A bare ``nn.Module`` has no ``__call__`` and cannot be used as a
        # model. An embedding whose output width equals the vocabulary size
        # maps integer token ids to one score per vocabulary entry, which is
        # enough to stand in for logits in these smoke tests.
        self.model = nn.Embedding(self.VOCAB_SIZE, self.VOCAB_SIZE)
        # Fixtures are MLX arrays because the code under test is MLX code.
        self.inputs = mx.array([[1, 2, 3], [4, 5, 6]])
        self.targets = mx.array([[1, 2, 3], [4, 5, 6]])
        self.lengths = mx.array([3, 3])

    def _assert_loss_outputs(self, loss, ntoks):
        """Check that both values returned by a loss function are MLX arrays."""
        # MLX's array type is ``mlx.core.array``; ``mlx.nn`` has no ``Tensor``.
        self.assertIsInstance(loss, mx.array)
        self.assertIsInstance(ntoks, mx.array)
        # MLX evaluates lazily; force the computation so that failures inside
        # the loss surface in the test instead of being silently deferred.
        mx.eval(loss, ntoks)

    def test_default_loss(self):
        """``default_loss`` returns a ``(loss, ntoks)`` pair of MLX arrays."""
        loss, ntoks = default_loss(
            self.model, self.inputs, self.targets, self.lengths
        )
        self._assert_loss_outputs(loss, ntoks)

    def test_instruct_loss(self):
        """``instruct_loss`` called without ``mask_input`` returns MLX arrays."""
        loss, ntoks = instruct_loss(
            self.model, self.inputs, self.targets, self.lengths
        )
        self._assert_loss_outputs(loss, ntoks)

    def test_instruct_loss_with_masking(self):
        """``instruct_loss`` with ``mask_input=True`` returns MLX arrays."""
        loss, ntoks = instruct_loss(
            self.model, self.inputs, self.targets, self.lengths, mask_input=True
        )
        self._assert_loss_outputs(loss, ntoks)

    def test_instruct_loss_without_masking(self):
        """``instruct_loss`` with ``mask_input=False`` returns MLX arrays."""
        loss, ntoks = instruct_loss(
            self.model, self.inputs, self.targets, self.lengths, mask_input=False
        )
        self._assert_loss_outputs(loss, ntoks)
|
|
|
|
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|