mlx-examples/llms/tests/test_trainer.py
Anupam Mediratta 5c89d1f6a6 Add instruct tuning support to LoRA training
Fixes #484

Add support for instruct tuning with input/output pairs and alternative loss functions.

* **llms/mlx_lm/lora.py**
  - Add `CompletionsDataset` class to support input/output pairs.
  - Modify `Dataset` class to handle different dataset types.
  - Update `main` function to include new dataset type.

* **llms/mlx_lm/tuner/trainer.py**
  - Modify `default_loss` function to support alternative loss functions.
  - Add new `instruct_loss` function for instruct tuning.

* **llms/mlx_lm/LORA.md**
  - Add instructions for instruct tuning with input/output pairs.
  - Update documentation to include alternative loss functions.

* **llms/tests/test_datasets.py**
  - Add tests for `CompletionsDataset` and `create_dataset` functions.

* **llms/tests/test_trainer.py**
  - Add tests for `default_loss` and `instruct_loss` functions.
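
Instruct tuning on input/output pairs generally means concatenating the tokenized input and output into one training sequence and recording where the input ends, so that a loss can later skip those tokens. The following is a minimal pure-Python sketch of that step. It assumes records with hypothetical `"input"`/`"output"` keys and uses a hypothetical whitespace `ToyTokenizer` as a stand-in. The actual `CompletionsDataset` may use different keys, a chat template, or special tokens.

```python
from typing import Callable, Dict, List, Tuple


class ToyTokenizer:
    """Hypothetical stand-in tokenizer: one id per whitespace-separated word."""

    def __init__(self) -> None:
        self.vocab: Dict[str, int] = {}

    def encode(self, text: str) -> List[int]:
        # Assign ids in order of first appearance.
        return [self.vocab.setdefault(word, len(self.vocab)) for word in text.split()]


def encode_pair(
    example: Dict[str, str],
    tokenize: Callable[[str], List[int]],
    input_key: str = "input",  # hypothetical record key
    output_key: str = "output",  # hypothetical record key
) -> Tuple[List[int], int]:
    """Tokenize one input/output record into a single sequence.

    Returns the concatenated token ids and the number of leading ids that
    come from the input, which a loss can later exclude.
    """
    input_ids = tokenize(example[input_key])
    output_ids = tokenize(example[output_key])
    return input_ids + output_ids, len(input_ids)


tok = ToyTokenizer()
ids, n_input = encode_pair(
    {"input": "What is MLX ?", "output": "An array framework ."}, tok.encode
)
print(ids, n_input)  # [0, 1, 2, 3, 4, 5, 6, 7] 4
```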

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/ml-explore/mlx-examples/issues/484?shareId=XXXX-XXXX-XXXX-XXXX).
2025-01-20 11:42:13 +05:30


import unittest

import mlx.core as mx
import mlx.nn as nn

# Import the package as `mlx_lm` (assumes the tests are run from `llms/`).
# `instruct_loss` and its `mask_input` keyword are added by this commit.
from mlx_lm.tuner.trainer import default_loss, instruct_loss


class TinyLM(nn.Module):
    """Minimal callable model mapping token ids to per-position logits."""

    def __init__(self, vocab_size: int = 16, dims: int = 8):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dims)
        self.head = nn.Linear(dims, vocab_size)

    def __call__(self, x: mx.array) -> mx.array:
        return self.head(self.embedding(x))


class TestLossFunctions(unittest.TestCase):
    def setUp(self):
        # A bare nn.Module() has no __call__, so the loss functions cannot run
        # it; use a tiny real model. Inputs are mx.array, not NumPy arrays.
        self.model = TinyLM()
        self.inputs = mx.array([[1, 2, 3], [4, 5, 6]])
        self.targets = mx.array([[1, 2, 3], [4, 5, 6]])
        self.lengths = mx.array([3, 3])

    def check_outputs(self, loss, ntoks):
        # MLX has no `nn.Tensor`; both return values are `mx.array` scalars.
        self.assertIsInstance(loss, mx.array)
        self.assertIsInstance(ntoks, mx.array)

    def test_default_loss(self):
        loss, ntoks = default_loss(self.model, self.inputs, self.targets, self.lengths)
        self.check_outputs(loss, ntoks)

    def test_instruct_loss(self):
        loss, ntoks = instruct_loss(self.model, self.inputs, self.targets, self.lengths)
        self.check_outputs(loss, ntoks)

    def test_instruct_loss_with_masking(self):
        loss, ntoks = instruct_loss(
            self.model, self.inputs, self.targets, self.lengths, mask_input=True
        )
        self.check_outputs(loss, ntoks)

    def test_instruct_loss_without_masking(self):
        loss, ntoks = instruct_loss(
            self.model, self.inputs, self.targets, self.lengths, mask_input=False
        )
        self.check_outputs(loss, ntoks)


if __name__ == "__main__":
    unittest.main()
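
The tests above only check return types. What `mask_input` is meant to change can be illustrated with a pure-Python sketch of token-level cross-entropy that optionally skips the input (prompt) positions. This is not the commit's `instruct_loss`. The helper name `masked_cross_entropy` and the per-row `prompt_lengths` argument are hypothetical, and targets are assumed to be already shifted so they line up with the logits.

```python
import math
from typing import Sequence, Tuple


def masked_cross_entropy(
    logits: Sequence[Sequence[Sequence[float]]],  # [batch][time][vocab]
    targets: Sequence[Sequence[int]],  # [batch][time], aligned with logits
    lengths: Sequence[int],  # number of valid positions per row
    prompt_lengths: Sequence[int],  # hypothetical: leading input positions per row
    mask_input: bool = True,
) -> Tuple[float, int]:
    """Return (mean cross-entropy over counted tokens, number of counted tokens)."""
    total, ntoks = 0.0, 0
    for row_logits, row_targets, length, n_prompt in zip(
        logits, targets, lengths, prompt_lengths
    ):
        # With masking, start after the input positions; otherwise score them all.
        start = n_prompt if mask_input else 0
        for t in range(start, length):
            scores = row_logits[t]
            m = max(scores)  # subtract the max for a numerically stable log-sum-exp
            log_z = m + math.log(sum(math.exp(s - m) for s in scores))
            total += log_z - scores[row_targets[t]]
            ntoks += 1
    if ntoks == 0:
        raise ValueError("no tokens left to score after masking")
    return total / ntoks, ntoks


# One row, two positions, vocabulary of 2; position 0 belongs to the input.
logits = [[[0.0, 0.0], [10.0, 0.0]]]
targets = [[1, 0]]
masked, n_masked = masked_cross_entropy(logits, targets, [2], [1], mask_input=True)
full, n_full = masked_cross_entropy(logits, targets, [2], [1], mask_input=False)
print(n_masked, n_full)  # 1 2
```

Dividing by the counted tokens keeps the result a per-token average in both modes, in the same way `default_loss` normalizes its summed loss by its token count.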