Module checks the weight on load_weights (#337)

* update module to check weights on load, also fix docs and reorganize tests
* nits + rebase
* a few more docs updates for Module
* use manual module file
* comment
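The headline change: ``load_weights`` now validates parameter names, types, and shapes before updating the model. A minimal sketch of the new failure mode, assuming a working ``mlx`` install (the model and arrays here are illustrative, not from the commit):

```python
import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(10, 10)

# All parameter names are present, but "weight" has the wrong shape.
# With the new strict default this raises instead of silently updating.
try:
    model.load_weights(
        [("weight", mx.zeros((5, 5))), ("bias", mx.zeros((10,)))]
    )
except ValueError as e:
    print(e)  # roughly: Expected shape (10, 10) but received shape (5, 5) ...
```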
		| @@ -1,7 +1,7 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
|  | ||||
| import textwrap | ||||
| from typing import Any, Callable, List, Optional, Union | ||||
| from typing import Any, Callable, List, Optional, Tuple, Union | ||||
|  | ||||
| import mlx.core as mx | ||||
| from mlx.utils import tree_flatten, tree_unflatten | ||||
| @@ -61,6 +61,7 @@ class Module(dict): | ||||
|  | ||||
|     @property | ||||
|     def training(self): | ||||
|         """Boolean indicating if the model is in training mode.""" | ||||
|         return self._training | ||||
|  | ||||
|     def _extra_repr(self): | ||||
| @@ -87,15 +88,83 @@ class Module(dict): | ||||
|     def __setattr__(self, key: str, val: Any): | ||||
|         self[key] = val | ||||
|  | ||||
|     def load_weights(self, file: str): | ||||
|     def load_weights( | ||||
|         self, | ||||
|         file_or_weights: Union[str, List[Tuple[str, mx.array]]], | ||||
|         strict: bool = True, | ||||
|     ): | ||||
|         """ | ||||
|         Load and update the model's weights from a `.npz` file. | ||||
|         Update the model's weights from a ``.npz`` file or a list. | ||||
|  | ||||
|         Args: | ||||
|             file_or_weights (str or list(tuple(str, mx.array))): The path to | ||||
|                 the weights ``.npz`` file or a list of pairs of parameter names | ||||
|                 and arrays. | ||||
|             strict (bool, optional): If ``True`` then checks that the provided | ||||
|               weights exactly match the parameters of the model. Otherwise, | ||||
|               only the weights actually contained in the model are loaded and | ||||
|               shapes are not checked. Default: ``True``. | ||||
|  | ||||
|         Example: | ||||
|  | ||||
|             .. code-block:: python | ||||
|  | ||||
|                 import mlx.core as mx | ||||
|                 import mlx.nn as nn | ||||
|                 model = nn.Linear(10, 10) | ||||
|  | ||||
|                 # Load from file | ||||
|                 model.load_weights("weights.npz") | ||||
|  | ||||
|                 # Load from list | ||||
|                 weights = [ | ||||
|                     ("weight", mx.random.uniform(shape=(10, 10))), | ||||
|                     ("bias",  mx.zeros((10,))), | ||||
|                 ] | ||||
|                 model.load_weights(weights) | ||||
|  | ||||
|                 # Missing weight | ||||
|                 weights = [ | ||||
|                     ("weight", mx.random.uniform(shape=(10, 10))), | ||||
|                 ] | ||||
|  | ||||
|                 # Raises a ValueError exception | ||||
|                 model.load_weights(weights) | ||||
|  | ||||
|                 # Ok, only updates the weight but not the bias | ||||
|                 model.load_weights(weights, strict=False) | ||||
|         """ | ||||
|         self.update(tree_unflatten(list(mx.load(file).items()))) | ||||
|         weights = file_or_weights | ||||
|         if isinstance(weights, str): | ||||
|             weights = list(mx.load(weights).items()) | ||||
|  | ||||
|         if strict: | ||||
|             new_weights = dict(weights) | ||||
|             curr_weights = dict(tree_flatten(self.parameters())) | ||||
|             if extras := (new_weights.keys() - curr_weights.keys()): | ||||
|                 extras = " ".join(extras) | ||||
|                 raise ValueError(f"Received parameters not in model: {extras}.") | ||||
|             if missing := (curr_weights.keys() - new_weights.keys()): | ||||
|                 missing = " ".join(missing) | ||||
|                 raise ValueError(f"Missing parameters: {missing}.") | ||||
|             for k, v in curr_weights.items(): | ||||
|                 v_new = new_weights[k] | ||||
|                 if not isinstance(v_new, mx.array): | ||||
|                     raise ValueError( | ||||
|                         "Expected mx.array but received " | ||||
|                         f"{type(v_new)} for parameter {k}" | ||||
|                     ) | ||||
|                 if v_new.shape != v.shape: | ||||
|                     raise ValueError( | ||||
|                         f"Expected shape {v.shape} but received " | ||||
|                         f"shape {v_new.shape} for parameter {k}" | ||||
|                     ) | ||||
|  | ||||
|         self.update(tree_unflatten(weights)) | ||||
|  | ||||
|     def save_weights(self, file: str): | ||||
|         """ | ||||
|         Save the model's weights to a `.npz` file. | ||||
|         Save the model's weights to a ``.npz`` file. | ||||
|         """ | ||||
|         mx.savez(file, **dict(tree_flatten(self.parameters()))) | ||||
|  | ||||
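Together with the stricter ``load_weights``, ``save_weights`` gives a checked round trip. A minimal usage sketch, assuming write access to the working directory (``linear.npz`` is an arbitrary file name):

```python
import mlx.nn as nn

model = nn.Linear(4, 4)
model.save_weights("linear.npz")  # flattened parameter tree -> .npz

clone = nn.Linear(4, 4)
clone.load_weights("linear.npz")  # strict=True: names and shapes must match
```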
| @@ -351,23 +420,26 @@ class Module(dict): | ||||
|         """Freeze the Module's parameters or some of them. Freezing a parameter means not | ||||
|         computing gradients for it. | ||||
|  | ||||
|         This function is idempotent ie freezing a frozen model is a noop. | ||||
|         This function is idempotent i.e. freezing a frozen model is a no-op. | ||||
|  | ||||
|         For instance to only train the attention parameters from a transformer: | ||||
|         Example: | ||||
|             For instance to only train the attention parameters from a Transformer: | ||||
|  | ||||
|             model = ... | ||||
|             model.freeze() | ||||
|             model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None) | ||||
|             .. code-block:: python | ||||
|  | ||||
|                 model = nn.Transformer() | ||||
|                 model.freeze() | ||||
|                 model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None) | ||||
|  | ||||
|         Args: | ||||
|             recurse (bool, optional): If True then freeze the parameters of the | ||||
|                 submodules as well (default: True). | ||||
|                 submodules as well. Default: ``True``. | ||||
|             keys (str or list[str], optional): If provided then only these | ||||
|                 parameters will be frozen otherwise all the parameters of a | ||||
|                 module. For instance freeze all biases by calling | ||||
|                 ``module.freeze(keys="bias")``. | ||||
|             strict (bool, optional): If set to True validate that the passed keys exist | ||||
|                 (default: False). | ||||
|             strict (bool, optional): If set to ``True`` validate that the passed keys exist. | ||||
|                 Default: ``False``. | ||||
|         """ | ||||
|  | ||||
|         def _freeze_impl(_, m): | ||||
| @@ -401,21 +473,25 @@ class Module(dict): | ||||
|         This function is idempotent i.e. unfreezing a model that is not | ||||
|         frozen is a no-op. | ||||
|  | ||||
|         For instance to only train the biases one can do: | ||||
|         Example: | ||||
|  | ||||
|             model = ... | ||||
|             model.freeze() | ||||
|             model.unfreeze(keys="bias") | ||||
|             For instance to only train the biases of a Transformer one can do: | ||||
|  | ||||
|             .. code-block:: python | ||||
|  | ||||
|                 model = nn.Transformer() | ||||
|                 model.freeze() | ||||
|                 model.unfreeze(keys="bias") | ||||
|  | ||||
|         Args: | ||||
|             recurse (bool, optional): If True then unfreeze the parameters of the | ||||
|                 submodules as well (default: True). | ||||
|                 submodules as well. Default: ``True``. | ||||
|             keys (str or list[str], optional): If provided then only these | ||||
|                 parameters will be unfrozen otherwise all the parameters of a | ||||
|                 module. For instance unfreeze all biases by calling | ||||
|                 ``module.unfreeze(keys="bias")``. | ||||
|             strict (bool, optional): If set to True validate that the passed keys exist | ||||
|                 (default: False). | ||||
|             strict (bool, optional): If set to ``True`` validate that the passed keys exist. | ||||
|                 Default: ``False``. | ||||
|         """ | ||||
|  | ||||
|         def _unfreeze_impl(_, m): | ||||
| @@ -432,10 +508,25 @@ class Module(dict): | ||||
|             _unfreeze_impl("", self) | ||||
|  | ||||
|     def train(self, mode: bool = True): | ||||
|         """Set the model in or out of training mode. | ||||
|  | ||||
|         Training mode only applies to certain layers. For example | ||||
|         :obj:`Dropout` applies a random mask in training mode, but is the | ||||
|         identity in evaluation mode. | ||||
|  | ||||
|         Args: | ||||
|             mode (bool): Indicate if the model should be in training or | ||||
|                 evaluation mode. Default: ``True``. | ||||
|         """ | ||||
|  | ||||
|         def _set_train(_, m): | ||||
|             m._training = mode | ||||
|  | ||||
|         self.apply_to_modules(_set_train) | ||||
|  | ||||
|     def eval(self): | ||||
|         """Set the model to evaluation mode. | ||||
|  | ||||
|         See :func:`train`. | ||||
|         """ | ||||
|         self.train(False) | ||||
|   | ||||
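The new ``train``/``eval`` docstrings are easiest to see in action with :obj:`Dropout`, the example the docstring itself cites. A small sketch, assuming a working ``mlx`` install:

```python
import mlx.core as mx
import mlx.nn as nn

drop = nn.Dropout(0.5)
x = mx.ones((2, 4))

drop.train()    # training mode (the default)
print(drop(x))  # random mask: entries are 0.0 or 2.0 (scaled by 1/(1-p))

drop.eval()     # same as drop.train(False)
print(drop(x))  # identity: all ones
```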
python/tests/test_losses.py (new file, 279 lines)
							| @@ -0,0 +1,279 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
|  | ||||
| import unittest | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
| import mlx_tests | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
| class TestLosses(mlx_tests.MLXTestCase): | ||||
|     def test_cross_entropy(self): | ||||
|         logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]]) | ||||
|         targets = mx.array([0, 1]) | ||||
|  | ||||
|         # Test with reduction 'none' | ||||
|         losses_none = nn.losses.cross_entropy(logits, targets, reduction="none") | ||||
|         expected_none = mx.array([0.0, 0.0]) | ||||
|         self.assertTrue(mx.array_equal(losses_none, expected_none)) | ||||
|  | ||||
|         # Test with reduction 'mean' | ||||
|         losses_mean = nn.losses.cross_entropy(logits, targets, reduction="mean") | ||||
|         expected_mean = mx.mean(expected_none) | ||||
|         self.assertEqual(losses_mean, expected_mean) | ||||
|  | ||||
|         # Test with reduction 'sum' | ||||
|         losses_sum = nn.losses.cross_entropy(logits, targets, reduction="sum") | ||||
|         expected_sum = mx.sum(expected_none) | ||||
|         self.assertEqual(losses_sum, expected_sum) | ||||
|  | ||||
|         # Test cases with weights and no label smoothing | ||||
|         logits = mx.array([[2.0, -1.0], [-1.0, 2.0]]) | ||||
|         targets = mx.array([0, 1]) | ||||
|         weights = mx.array([1.0, 2.0]) | ||||
|  | ||||
|         # Reduction 'none' | ||||
|         losses_none = nn.losses.cross_entropy( | ||||
|             logits, | ||||
|             targets, | ||||
|             weights=weights, | ||||
|             reduction="none", | ||||
|         ) | ||||
|         expected_none = mx.array([0.04858735, 0.0971747])  # Calculated losses | ||||
|         self.assertTrue( | ||||
|             np.allclose(losses_none, expected_none, atol=1e-5), | ||||
|             "Test case failed for cross_entropy loss --reduction='none' --weights=[1.0, 2.0]", | ||||
|         ) | ||||
|  | ||||
|         # Reduction 'mean' | ||||
|         losses_mean = nn.losses.cross_entropy( | ||||
|             logits, | ||||
|             targets, | ||||
|             weights=weights, | ||||
|             reduction="mean", | ||||
|         ) | ||||
|         expected_mean = mx.mean(expected_none) | ||||
|         self.assertTrue( | ||||
|             np.allclose(losses_mean, expected_mean, atol=1e-5), | ||||
|             "Test case failed for cross_entropy loss --reduction='mean' --weights=[1.0, 2.0]", | ||||
|         ) | ||||
|  | ||||
|         # Reduction 'sum' | ||||
|         losses_sum = nn.losses.cross_entropy( | ||||
|             logits, | ||||
|             targets, | ||||
|             weights=weights, | ||||
|             reduction="sum", | ||||
|         ) | ||||
|         expected_sum = mx.sum(expected_none) | ||||
|         self.assertTrue( | ||||
|             np.allclose(losses_sum, expected_sum, atol=1e-5), | ||||
|             "Test case failed for cross_entropy loss --reduction='sum' --weights=[1.0, 2.0]", | ||||
|         ) | ||||
|  | ||||
|         # Test case with equal weights and label smoothing > 0 | ||||
|         logits = mx.array( | ||||
|             [[0, 0.2, 0.7, 0.1, 0], [0, 0.9, 0.2, 0.2, 1], [1, 0.2, 0.7, 0.9, 1]] | ||||
|         ) | ||||
|         target = mx.array([2, 1, 0]) | ||||
|  | ||||
|         losses_none = nn.losses.cross_entropy( | ||||
|             logits, target, label_smoothing=0.3, reduction="none" | ||||
|         ) | ||||
|         expected_none = mx.array([1.29693, 1.38617, 1.48176]) | ||||
|         self.assertTrue( | ||||
|             mx.allclose(expected_none, losses_none), | ||||
|             "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='none'", | ||||
|         ) | ||||
|  | ||||
|         expected_mean = mx.mean(expected_none) | ||||
|         losses_mean = nn.losses.cross_entropy( | ||||
|             logits, target, label_smoothing=0.3, reduction="mean" | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             mx.allclose(losses_mean, expected_mean), | ||||
|             "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='mean'", | ||||
|         ) | ||||
|  | ||||
|         expected_sum = mx.sum(expected_none) | ||||
|         losses_sum = nn.losses.cross_entropy( | ||||
|             logits, target, label_smoothing=0.3, reduction="sum" | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             mx.allclose(losses_sum, expected_sum), | ||||
|             "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='sum'", | ||||
|         ) | ||||
|  | ||||
|     def test_l1_loss(self): | ||||
|         predictions = mx.array([0.5, 0.2, 0.9, 0.0]) | ||||
|         targets = mx.array([0.5, 0.2, 0.9, 0.0]) | ||||
|  | ||||
|         # Expected result | ||||
|         expected_none = mx.array([0, 0, 0, 0]).astype(mx.float32) | ||||
|         expected_sum = mx.sum(expected_none) | ||||
|         expected_mean = mx.mean(expected_none) | ||||
|  | ||||
|         losses = nn.losses.l1_loss(predictions, targets, reduction="none") | ||||
|         self.assertTrue( | ||||
|             mx.array_equal(losses, expected_none), | ||||
|             "Test failed for l1_loss --reduction='none'", | ||||
|         ) | ||||
|  | ||||
|         losses = nn.losses.l1_loss(predictions, targets, reduction="sum") | ||||
|         self.assertTrue(mx.array_equal(losses, expected_sum)) | ||||
|  | ||||
|         losses = nn.losses.l1_loss(predictions, targets, reduction="mean") | ||||
|         self.assertTrue(mx.array_equal(losses, expected_mean)) | ||||
|  | ||||
|     def test_mse_loss(self): | ||||
|         predictions = mx.array([0.5, 0.2, 0.9, 0.0]) | ||||
|         targets = mx.array([0.7, 0.1, 0.8, 0.2]) | ||||
|  | ||||
|         expected_none = mx.array([0.04, 0.01, 0.01, 0.04]) | ||||
|         expected_mean = mx.mean(expected_none) | ||||
|         expected_sum = mx.sum(expected_none) | ||||
|  | ||||
|         # Test with reduction 'none' | ||||
|         losses_none = nn.losses.mse_loss(predictions, targets, reduction="none") | ||||
|         self.assertTrue( | ||||
|             np.allclose(losses_none, expected_none, 1e-5), | ||||
|             "Test case failed for mse_loss --reduction='none'", | ||||
|         ) | ||||
|  | ||||
|         # Test with reduction 'mean' | ||||
|         losses_mean = nn.losses.mse_loss(predictions, targets, reduction="mean") | ||||
|         self.assertEqual( | ||||
|             losses_mean, | ||||
|             expected_mean, | ||||
|             "Test case failed for mse_loss --reduction='mean'", | ||||
|         ) | ||||
|  | ||||
|         # Test with reduction 'sum' | ||||
|         losses_sum = nn.losses.mse_loss(predictions, targets, reduction="sum") | ||||
|         self.assertEqual( | ||||
|             losses_sum, expected_sum, "Test case failed for mse_loss --reduction='sum'" | ||||
|         ) | ||||
|  | ||||
|     def test_smooth_l1_loss(self): | ||||
|         predictions = mx.array([1.5, 2.5, 0.5, 3.5]) | ||||
|         targets = mx.array([1.0, 2.0, 0.5, 2.5]) | ||||
|         beta = 1.0 | ||||
|  | ||||
|         # Expected results | ||||
|         expected_none = mx.array([0.125, 0.125, 0.0, 0.5]) | ||||
|         expected_sum = mx.sum(expected_none) | ||||
|         expected_mean = mx.mean(expected_none) | ||||
|  | ||||
|         # Test with reduction 'none' | ||||
|         loss_none = nn.losses.smooth_l1_loss( | ||||
|             predictions, targets, beta, reduction="none" | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             mx.array_equal(loss_none, expected_none), | ||||
|             "Test case failed for smooth_l1_loss --reduction='none'", | ||||
|         ) | ||||
|  | ||||
|         # Test with reduction 'sum' | ||||
|         loss_sum = nn.losses.smooth_l1_loss(predictions, targets, beta, reduction="sum") | ||||
|         self.assertEqual( | ||||
|             loss_sum, | ||||
|             expected_sum, | ||||
|             "Test case failed for smooth_l1_loss --reduction='sum'", | ||||
|         ) | ||||
|  | ||||
|         # Test with reduction 'mean' | ||||
|         loss_mean = nn.losses.smooth_l1_loss( | ||||
|             predictions, targets, beta, reduction="mean" | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             loss_mean, | ||||
|             expected_mean, | ||||
|             "Test case failed for smooth_l1_loss --reduction='mean'", | ||||
|         ) | ||||
|  | ||||
|     def test_nll_loss(self): | ||||
|         logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]]) | ||||
|         targets = mx.array([0, 1]) | ||||
|  | ||||
|         # Test with reduction 'none' | ||||
|         losses_none = nn.losses.nll_loss(logits, targets, reduction="none") | ||||
|         expected_none = mx.array([0.0, 0.0]) | ||||
|         self.assertTrue(mx.array_equal(losses_none, expected_none)) | ||||
|  | ||||
|         # Test with reduction 'mean' | ||||
|         losses_mean = nn.losses.nll_loss(logits, targets, reduction="mean") | ||||
|         expected_mean = mx.mean(expected_none) | ||||
|         self.assertEqual(losses_mean, expected_mean) | ||||
|  | ||||
|         # Test with reduction 'sum' | ||||
|         losses_sum = nn.losses.nll_loss(logits, targets, reduction="sum") | ||||
|         expected_sum = mx.sum(expected_none) | ||||
|         self.assertEqual(losses_sum, expected_sum) | ||||
|  | ||||
|     def test_kl_div_loss(self): | ||||
|         p_logits = mx.log(mx.array([[0.5, 0.5], [0.8, 0.2]])) | ||||
|         q_logits = mx.log(mx.array([[0.5, 0.5], [0.2, 0.8]])) | ||||
|  | ||||
|         # Test with reduction 'none' | ||||
|         losses_none = nn.losses.kl_div_loss(p_logits, q_logits, reduction="none") | ||||
|         expected_none = mx.array([0.0, 0.831777]) | ||||
|         self.assertTrue(mx.allclose(losses_none, expected_none)) | ||||
|  | ||||
|         # Test with reduction 'mean' | ||||
|         losses_mean = nn.losses.kl_div_loss(p_logits, q_logits, reduction="mean") | ||||
|         expected_mean = mx.mean(expected_none) | ||||
|         self.assertTrue(mx.allclose(losses_mean, expected_mean)) | ||||
|  | ||||
|         # Test with reduction 'sum' | ||||
|         losses_sum = nn.losses.kl_div_loss(p_logits, q_logits, reduction="sum") | ||||
|         expected_sum = mx.sum(expected_none) | ||||
|         self.assertTrue(mx.allclose(losses_sum, expected_sum)) | ||||
|  | ||||
|     def test_triplet_loss(self): | ||||
|         anchors = mx.array([[1, 2, 3], [1, 2, 3]]) | ||||
|         positives = mx.array([[4, 5, 6], [0, -1, 2]]) | ||||
|         negatives = mx.array([[7, 8, 9], [3, 2, 3]]) | ||||
|  | ||||
|         # Test with reduction 'none' | ||||
|         losses_none = nn.losses.triplet_loss( | ||||
|             anchors, positives, negatives, reduction="none" | ||||
|         ) | ||||
|         expected_none = mx.array([0, 2.31662]) | ||||
|         self.assertTrue(mx.allclose(losses_none, expected_none)) | ||||
|  | ||||
|         # Test with reduction 'mean' | ||||
|         losses_mean = nn.losses.triplet_loss( | ||||
|             anchors, positives, negatives, reduction="mean" | ||||
|         ) | ||||
|         expected_mean = mx.mean(expected_none) | ||||
|         self.assertTrue(mx.allclose(losses_mean, expected_mean)) | ||||
|  | ||||
|         # Test with reduction 'sum' | ||||
|         losses_sum = nn.losses.triplet_loss( | ||||
|             anchors, positives, negatives, reduction="sum" | ||||
|         ) | ||||
|         expected_sum = mx.sum(expected_none) | ||||
|         self.assertTrue(mx.allclose(losses_sum, expected_sum)) | ||||
|  | ||||
|     def test_hinge_loss(self): | ||||
|         inputs = mx.ones((2, 4)) | ||||
|         targets = mx.zeros((2, 4)) | ||||
|         loss = nn.losses.hinge_loss(inputs, targets, reduction="mean") | ||||
|         self.assertEqual(loss, 1.0) | ||||
|  | ||||
|     def test_huber_loss(self): | ||||
|         inputs = mx.ones((2, 4)) | ||||
|         targets = mx.zeros((2, 4)) | ||||
|         loss = nn.losses.huber_loss(inputs, targets, reduction="mean") | ||||
|         self.assertEqual(loss, 0.5) | ||||
|  | ||||
|     def test_log_cosh_loss(self): | ||||
|         inputs = mx.ones((2, 4)) | ||||
|         targets = mx.zeros((2, 4)) | ||||
|         loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean") | ||||
|         self.assertAlmostEqual(loss.item(), 0.433781, places=6) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
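The hard-coded constants in these tests follow from the standard loss definitions. A plain-Python sanity check, not part of the test suite, showing where a few of them come from (assuming the usual definitions of each loss, with margin and scaling defaults as in the tests above):

```python
import math

# weighted cross entropy: -w * log softmax(logits)[target]
# logits [2.0, -1.0], target 0  ->  log(1 + e^-3)
assert abs(math.log(1 + math.exp(-3.0)) - 0.04858735) < 1e-6      # weight 1.0
assert abs(2 * math.log(1 + math.exp(-3.0)) - 0.0971747) < 1e-6   # weight 2.0

# KL(p || q) with p = [0.8, 0.2], q = [0.2, 0.8]: (0.8 - 0.2) * log 4
assert abs(0.6 * math.log(4.0) - 0.831777) < 1e-5

# triplet loss, second row: max(||a-p|| - ||a-n|| + margin, 0), margin = 1
assert abs(max(math.sqrt(11.0) - 2.0 + 1.0, 0.0) - 2.31662) < 1e-5

# log-cosh of a unit error: log(cosh(1))
assert abs(math.log(math.cosh(1.0)) - 0.433781) < 1e-6
```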
| @@ -11,7 +11,112 @@ import numpy as np | ||||
| from mlx.utils import tree_flatten, tree_map, tree_unflatten | ||||
|  | ||||
|  | ||||
| class TestNN(mlx_tests.MLXTestCase): | ||||
| class TestBase(mlx_tests.MLXTestCase): | ||||
|     def test_module_utilities(self): | ||||
|         m = nn.Sequential( | ||||
|             nn.Sequential(nn.Linear(2, 10), nn.relu), | ||||
|             nn.Sequential(nn.Linear(10, 10), nn.ReLU()), | ||||
|             nn.Linear(10, 1), | ||||
|             mx.sigmoid, | ||||
|         ) | ||||
|  | ||||
|         children = m.children() | ||||
|         self.assertTrue(isinstance(children, dict)) | ||||
|         self.assertEqual(len(children), 1) | ||||
|         self.assertTrue(isinstance(children["layers"], list)) | ||||
|         self.assertEqual(len(children["layers"]), 4) | ||||
|         self.assertEqual(children["layers"][3], {}) | ||||
|         flat_children = tree_flatten(children, is_leaf=nn.Module.is_module) | ||||
|         self.assertEqual(len(flat_children), 3) | ||||
|  | ||||
|         leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module) | ||||
|         self.assertEqual(len(leaves), 4) | ||||
|         self.assertEqual(leaves[0][0], "layers.0.layers.0") | ||||
|         self.assertEqual(leaves[1][0], "layers.1.layers.0") | ||||
|         self.assertEqual(leaves[2][0], "layers.1.layers.1") | ||||
|         self.assertEqual(leaves[3][0], "layers.2") | ||||
|         self.assertTrue(leaves[0][1] is m.layers[0].layers[0]) | ||||
|         self.assertTrue(leaves[1][1] is m.layers[1].layers[0]) | ||||
|         self.assertTrue(leaves[2][1] is m.layers[1].layers[1]) | ||||
|         self.assertTrue(leaves[3][1] is m.layers[2]) | ||||
|  | ||||
|         m.eval() | ||||
|  | ||||
|         def assert_not_training(k, m): | ||||
|             self.assertFalse(m.training) | ||||
|  | ||||
|         m.apply_to_modules(assert_not_training) | ||||
|  | ||||
|         m.train() | ||||
|  | ||||
|         def assert_training(k, m): | ||||
|             self.assertTrue(m.training) | ||||
|  | ||||
|         m.apply_to_modules(assert_training) | ||||
|  | ||||
|     def test_io(self): | ||||
|         def make_model(): | ||||
|             return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2)) | ||||
|  | ||||
|         m = make_model() | ||||
|         tdir = tempfile.TemporaryDirectory() | ||||
|         file = os.path.join(tdir.name, "model.npz") | ||||
|         m.save_weights(file) | ||||
|         m_load = make_model() | ||||
|         m_load.load_weights(file) | ||||
|         tdir.cleanup() | ||||
|  | ||||
|         eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters()) | ||||
|         self.assertTrue(all(tree_flatten(eq_tree))) | ||||
|  | ||||
|     def test_load_from_weights(self): | ||||
|         m = nn.Linear(2, 2) | ||||
|  | ||||
|         # Too few weights | ||||
|         weights = [("weight", mx.ones((2, 2)))] | ||||
|         with self.assertRaises(ValueError): | ||||
|             m.load_weights(weights) | ||||
|  | ||||
|         m.load_weights(weights, strict=False) | ||||
|         self.assertTrue(mx.array_equal(m.weight, weights[0][1])) | ||||
|  | ||||
|         # Wrong name | ||||
|         with self.assertRaises(ValueError): | ||||
|             m.load_weights([("weihgt", mx.ones((2, 2)))]) | ||||
|  | ||||
|         # Ok | ||||
|         m.load_weights([("weihgt", mx.ones((2, 2)))], strict=False) | ||||
|  | ||||
|         # Too many weights | ||||
|         with self.assertRaises(ValueError): | ||||
|             m.load_weights( | ||||
|                 [ | ||||
|                     ("weight", mx.ones((2, 2))), | ||||
|                     ("bias", mx.ones((2,))), | ||||
|                     ("bias2", mx.ones((2,))), | ||||
|                 ] | ||||
|             ) | ||||
|  | ||||
|         # Wrong shape | ||||
|         with self.assertRaises(ValueError): | ||||
|             m.load_weights( | ||||
|                 [ | ||||
|                     ("weight", mx.ones((2, 2))), | ||||
|                     ("bias", mx.ones((2, 1))), | ||||
|                 ] | ||||
|             ) | ||||
|  | ||||
|         # Wrong type | ||||
|         with self.assertRaises(ValueError): | ||||
|             m.load_weights( | ||||
|                 [ | ||||
|                     ("weight", mx.ones((2, 2))), | ||||
|                     ("bias", 3), | ||||
|                 ] | ||||
|             ) | ||||
|  | ||||
|  | ||||
| class TestLayers(mlx_tests.MLXTestCase): | ||||
|     def test_identity(self): | ||||
|         inputs = mx.zeros((10, 4)) | ||||
|         layer = nn.Identity() | ||||
| @@ -31,272 +136,6 @@ class TestNN(mlx_tests.MLXTestCase): | ||||
|         outputs = layer(inputs1, inputs2) | ||||
|         self.assertEqual(tuple(outputs.shape), (10, 6)) | ||||
|  | ||||
|     def test_cross_entropy(self): | ||||
|         logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]]) | ||||
|         targets = mx.array([0, 1]) | ||||
|  | ||||
|         # Test with reduction 'none' | ||||
|         losses_none = nn.losses.cross_entropy(logits, targets, reduction="none") | ||||
|         expected_none = mx.array([0.0, 0.0]) | ||||
|         self.assertTrue(mx.array_equal(losses_none, expected_none)) | ||||
|  | ||||
|         # Test with reduction 'mean' | ||||
|         losses_mean = nn.losses.cross_entropy(logits, targets, reduction="mean") | ||||
|         expected_mean = mx.mean(expected_none) | ||||
|         self.assertEqual(losses_mean, expected_mean) | ||||
|  | ||||
|         # Test with reduction 'sum' | ||||
|         losses_sum = nn.losses.cross_entropy(logits, targets, reduction="sum") | ||||
|         expected_sum = mx.sum(expected_none) | ||||
|         self.assertEqual(losses_sum, expected_sum) | ||||
|  | ||||
|         # Test cases with weights and no label smoothing | ||||
|         logits = mx.array([[2.0, -1.0], [-1.0, 2.0]]) | ||||
|         targets = mx.array([0, 1]) | ||||
|         weights = mx.array([1.0, 2.0]) | ||||
|  | ||||
|         # Reduction 'none' | ||||
|         losses_none = nn.losses.cross_entropy( | ||||
|             logits, | ||||
|             targets, | ||||
|             weights=weights, | ||||
|             reduction="none", | ||||
|         ) | ||||
|         expected_none = mx.array([0.04858735, 0.0971747])  # Calculated losses | ||||
|         self.assertTrue( | ||||
|             np.allclose(losses_none, expected_none, atol=1e-5), | ||||
|             "Test case failed for cross_entropy loss --reduction='none' --weights=[1.0, 2.0]", | ||||
|         ) | ||||
|  | ||||
|         # Reduction 'mean' | ||||
|         losses_mean = nn.losses.cross_entropy( | ||||
|             logits, | ||||
|             targets, | ||||
|             weights=weights, | ||||
|             reduction="mean", | ||||
|         ) | ||||
|         expected_mean = mx.mean(expected_none) | ||||
|         self.assertTrue( | ||||
|             np.allclose(losses_mean, expected_mean, atol=1e-5), | ||||
|             "Test case failed for cross_entropy loss --reduction='mean' --weights=[1.0, 2.0]", | ||||
|         ) | ||||
|  | ||||
|         # Reduction 'sum' | ||||
|         losses_sum = nn.losses.cross_entropy( | ||||
|             logits, | ||||
|             targets, | ||||
|             weights=weights, | ||||
|             reduction="sum", | ||||
|         ) | ||||
|         expected_sum = mx.sum(expected_none) | ||||
|         self.assertTrue( | ||||
|             np.allclose(losses_sum, expected_sum, atol=1e-5), | ||||
|             "Test case failed for cross_entropy loss --reduction='sum' --weights=[1.0, 2.0]", | ||||
|         ) | ||||
|  | ||||
|         # Test case with equal weights and label smoothing > 0 | ||||
|         logits = mx.array( | ||||
|             [[0, 0.2, 0.7, 0.1, 0], [0, 0.9, 0.2, 0.2, 1], [1, 0.2, 0.7, 0.9, 1]] | ||||
|         ) | ||||
|         target = mx.array([2, 1, 0]) | ||||
|  | ||||
|         losses_none = nn.losses.cross_entropy( | ||||
|             logits, target, label_smoothing=0.3, reduction="none" | ||||
|         ) | ||||
|         expected_none = mx.array([1.29693, 1.38617, 1.48176]) | ||||
|         self.assertTrue( | ||||
|             mx.allclose(expected_none, losses_none), | ||||
|             "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='none'", | ||||
|         ) | ||||
|  | ||||
|         expected_mean = mx.mean(expected_none) | ||||
|         losses_mean = nn.losses.cross_entropy( | ||||
|             logits, target, label_smoothing=0.3, reduction="mean" | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             mx.allclose(losses_mean, expected_mean), | ||||
|             "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='mean'", | ||||
|         ) | ||||
|  | ||||
|         expected_sum = mx.sum(expected_none) | ||||
|         losses_sum = nn.losses.cross_entropy( | ||||
|             logits, target, label_smoothing=0.3, reduction="sum" | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             mx.allclose(losses_sum, expected_sum), | ||||
|             "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='sum'", | ||||
|         ) | ||||
|  | ||||
|     def test_l1_loss(self): | ||||
|         predictions = mx.array([0.5, 0.2, 0.9, 0.0]) | ||||
|         targets = mx.array([0.5, 0.2, 0.9, 0.0]) | ||||
|  | ||||
|         # Expected result | ||||
|         expected_none = mx.array([0, 0, 0, 0]).astype(mx.float32) | ||||
|         expected_sum = mx.sum(expected_none) | ||||
|         expected_mean = mx.mean(expected_none) | ||||
|  | ||||
|         losses = nn.losses.l1_loss(predictions, targets, reduction="none") | ||||
|         self.assertTrue( | ||||
|             mx.array_equal(losses, expected_none), | ||||
|             "Test failed for l1_loss --reduction='none'", | ||||
|         ) | ||||
|  | ||||
|         losses = nn.losses.l1_loss(predictions, targets, reduction="sum") | ||||
|         self.assertTrue(mx.array_equal(losses, expected_sum)) | ||||
|  | ||||
|         losses = nn.losses.l1_loss(predictions, targets, reduction="mean") | ||||
|         self.assertTrue(mx.array_equal(losses, expected_mean)) | ||||
|  | ||||
|     def test_mse_loss(self): | ||||
|         predictions = mx.array([0.5, 0.2, 0.9, 0.0]) | ||||
|         targets = mx.array([0.7, 0.1, 0.8, 0.2]) | ||||
|  | ||||
|         expected_none = mx.array([0.04, 0.01, 0.01, 0.04]) | ||||
|         expected_mean = mx.mean(expected_none) | ||||
|         expected_sum = mx.sum(expected_none) | ||||
|  | ||||
|         # Test with reduction 'none' | ||||
|         losses_none = nn.losses.mse_loss(predictions, targets, reduction="none") | ||||
|         self.assertTrue( | ||||
|             np.allclose(losses_none, expected_none, 1e-5), | ||||
|             "Test case failed for mse_loss --reduction='none'", | ||||
|         ) | ||||
|  | ||||
|         # Test with reduction 'mean' | ||||
|         losses_mean = nn.losses.mse_loss(predictions, targets, reduction="mean") | ||||
|         self.assertEqual( | ||||
|             losses_mean, | ||||
|             expected_mean, | ||||
|             "Test case failed for mse_loss --reduction='mean'", | ||||
|         ) | ||||
|  | ||||
|         # Test with reduction 'sum' | ||||
|         losses_sum = nn.losses.mse_loss(predictions, targets, reduction="sum") | ||||
|         self.assertEqual( | ||||
|             losses_sum, expected_sum, "Test case failed for mse_loss --reduction='sum'" | ||||
|         ) | ||||
|  | ||||
|     def test_smooth_l1_loss(self): | ||||
|         predictions = mx.array([1.5, 2.5, 0.5, 3.5]) | ||||
|         targets = mx.array([1.0, 2.0, 0.5, 2.5]) | ||||
|         beta = 1.0 | ||||
|  | ||||
|         # Expected results | ||||
|         expected_none = mx.array([0.125, 0.125, 0.0, 0.5]) | ||||
|         expected_sum = mx.sum(expected_none) | ||||
|         expected_mean = mx.mean(expected_none) | ||||
|  | ||||
|         # Test with reduction 'none' | ||||
|         loss_none = nn.losses.smooth_l1_loss( | ||||
|             predictions, targets, beta, reduction="none" | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             mx.array_equal(loss_none, expected_none), | ||||
|             "Test case failed for smooth_l1_loss --reduction='none'", | ||||
|         ) | ||||
|  | ||||
|         # Test with reduction 'sum' | ||||
|         loss_sum = nn.losses.smooth_l1_loss(predictions, targets, beta, reduction="sum") | ||||
|         self.assertEqual( | ||||
|             loss_sum, | ||||
|             expected_sum, | ||||
|             "Test case failed for smooth_l1_loss --reduction='sum'", | ||||
|         ) | ||||
|  | ||||
|         # Test with reduction 'mean' | ||||
|         loss_mean = nn.losses.smooth_l1_loss( | ||||
|             predictions, targets, beta, reduction="mean" | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             loss_mean, | ||||
|             expected_mean, | ||||
|             "Test case failed for smooth_l1_loss --reduction='mean'", | ||||
|         ) | ||||
|  | ||||
|     def test_nll_loss(self): | ||||
|         logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]]) | ||||
|         targets = mx.array([0, 1]) | ||||
|  | ||||
|         # Test with reduction 'none' | ||||
|         losses_none = nn.losses.nll_loss(logits, targets, reduction="none") | ||||
|         expected_none = mx.array([0.0, 0.0]) | ||||
|         self.assertTrue(mx.array_equal(losses_none, expected_none)) | ||||
|  | ||||
|         # Test with reduction 'mean' | ||||
|         losses_mean = nn.losses.nll_loss(logits, targets, reduction="mean") | ||||
|         expected_mean = mx.mean(expected_none) | ||||
|         self.assertEqual(losses_mean, expected_mean) | ||||
|  | ||||
|         # Test with reduction 'sum' | ||||
|         losses_sum = nn.losses.nll_loss(logits, targets, reduction="sum") | ||||
|         expected_sum = mx.sum(expected_none) | ||||
|         self.assertEqual(losses_sum, expected_sum) | ||||
|  | ||||
|     def test_kl_div_loss(self): | ||||
|         p_logits = mx.log(mx.array([[0.5, 0.5], [0.8, 0.2]])) | ||||
|         q_logits = mx.log(mx.array([[0.5, 0.5], [0.2, 0.8]])) | ||||
|  | ||||
|         # Test with reduction 'none' | ||||
|         losses_none = nn.losses.kl_div_loss(p_logits, q_logits, reduction="none") | ||||
|         expected_none = mx.array([0.0, 0.831777]) | ||||
|         self.assertTrue(mx.allclose(losses_none, expected_none)) | ||||
|  | ||||
|         # Test with reduction 'mean' | ||||
|         losses_mean = nn.losses.kl_div_loss(p_logits, q_logits, reduction="mean") | ||||
|         expected_mean = mx.mean(expected_none) | ||||
|         self.assertTrue(mx.allclose(losses_mean, expected_mean)) | ||||
|  | ||||
|         # Test with reduction 'sum' | ||||
|         losses_sum = nn.losses.kl_div_loss(p_logits, q_logits, reduction="sum") | ||||
|         expected_sum = mx.sum(expected_none) | ||||
|         self.assertTrue(mx.allclose(losses_sum, expected_sum)) | ||||
|  | ||||
|     def test_triplet_loss(self): | ||||
|         anchors = mx.array([[1, 2, 3], [1, 2, 3]]) | ||||
|         positives = mx.array([[4, 5, 6], [0, -1, 2]]) | ||||
|         negatives = mx.array([[7, 8, 9], [3, 2, 3]]) | ||||
|  | ||||
|         # Test with reduction 'none' | ||||
|         losses_none = nn.losses.triplet_loss( | ||||
|             anchors, positives, negatives, reduction="none" | ||||
|         ) | ||||
|         expected_none = mx.array([0, 2.31662]) | ||||
|         self.assertTrue(mx.allclose(losses_none, expected_none)) | ||||
|  | ||||
|         # Test with reduction 'mean' | ||||
|         losses_mean = nn.losses.triplet_loss( | ||||
|             anchors, positives, negatives, reduction="mean" | ||||
|         ) | ||||
|         expected_mean = mx.mean(expected_none) | ||||
|         self.assertTrue(mx.allclose(losses_mean, expected_mean)) | ||||
|  | ||||
|         # Test with reduction 'sum' | ||||
|         losses_sum = nn.losses.triplet_loss( | ||||
|             anchors, positives, negatives, reduction="sum" | ||||
|         ) | ||||
|         expected_sum = mx.sum(expected_none) | ||||
|         self.assertTrue(mx.allclose(losses_sum, expected_sum)) | ||||
|  | ||||
|     def test_gelu(self): | ||||
|         inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414] | ||||
|  | ||||
|         # From: jax.nn.gelu(np.array(inputs), approximate=False) | ||||
|         expected = np.array( | ||||
|             [1.0093501, -0.16925684, 0.22918941, 0.60498625, 0.49459383] | ||||
|         ) | ||||
|  | ||||
|         out = nn.GELU()(mx.array(inputs)) | ||||
|         self.assertTrue(np.allclose(out, expected)) | ||||
|  | ||||
|         # Crudely check the approximations | ||||
|         x = mx.arange(-6.0, 6.0, 12 / 100) | ||||
|         y = nn.gelu(x) | ||||
|         y_hat1 = nn.gelu_approx(x) | ||||
|         y_hat2 = nn.gelu_fast_approx(x) | ||||
|         self.assertLess(mx.abs(y - y_hat1).max(), 0.0003) | ||||
|         self.assertLess(mx.abs(y - y_hat2).max(), 0.02) | ||||
|  | ||||
|     def test_group_norm(self): | ||||
|         x = mx.arange(100, dtype=mx.float32) | ||||
|         x = x.reshape(1, 10, 10, 1) | ||||
| @@ -570,47 +409,24 @@ class TestNN(mlx_tests.MLXTestCase): | ||||
|         y2 = m(x) | ||||
|         self.assertTrue(mx.array_equal(y, y2)) | ||||
|  | ||||
|     def test_module_utilities(self): | ||||
|         m = nn.Sequential( | ||||
|             nn.Sequential(nn.Linear(2, 10), nn.relu), | ||||
|             nn.Sequential(nn.Linear(10, 10), nn.ReLU()), | ||||
|             nn.Linear(10, 1), | ||||
|             mx.sigmoid, | ||||
|     def test_gelu(self): | ||||
|         inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414] | ||||
|  | ||||
|         # From: jax.nn.gelu(np.array(inputs), approximate=False) | ||||
|         expected = np.array( | ||||
|             [1.0093501, -0.16925684, 0.22918941, 0.60498625, 0.49459383] | ||||
|         ) | ||||
|  | ||||
|         children = m.children() | ||||
|         self.assertTrue(isinstance(children, dict)) | ||||
|         self.assertEqual(len(children), 1) | ||||
|         self.assertTrue(isinstance(children["layers"], list)) | ||||
|         self.assertEqual(len(children["layers"]), 4) | ||||
|         self.assertEqual(children["layers"][3], {}) | ||||
|         flat_children = tree_flatten(children, is_leaf=nn.Module.is_module) | ||||
|         self.assertEqual(len(flat_children), 3) | ||||
|         out = nn.GELU()(mx.array(inputs)) | ||||
|         self.assertTrue(np.allclose(out, expected)) | ||||
|  | ||||
|         leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module) | ||||
|         self.assertEqual(len(leaves), 4) | ||||
|         self.assertEqual(leaves[0][0], "layers.0.layers.0") | ||||
|         self.assertEqual(leaves[1][0], "layers.1.layers.0") | ||||
|         self.assertEqual(leaves[2][0], "layers.1.layers.1") | ||||
|         self.assertEqual(leaves[3][0], "layers.2") | ||||
|         self.assertTrue(leaves[0][1] is m.layers[0].layers[0]) | ||||
|         self.assertTrue(leaves[1][1] is m.layers[1].layers[0]) | ||||
|         self.assertTrue(leaves[2][1] is m.layers[1].layers[1]) | ||||
|         self.assertTrue(leaves[3][1] is m.layers[2]) | ||||
|  | ||||
|         m.eval() | ||||
|  | ||||
|         def assert_not_training(k, m): | ||||
|             self.assertFalse(m.training) | ||||
|  | ||||
|         m.apply_to_modules(assert_not_training) | ||||
|  | ||||
|         m.train() | ||||
|  | ||||
|         def assert_training(k, m): | ||||
|             self.assertTrue(m.training) | ||||
|  | ||||
|         m.apply_to_modules(assert_training) | ||||
|         # Crudely check the approximations | ||||
|         x = mx.arange(-6.0, 6.0, 12 / 100) | ||||
|         y = nn.gelu(x) | ||||
|         y_hat1 = nn.gelu_approx(x) | ||||
|         y_hat2 = nn.gelu_fast_approx(x) | ||||
|         self.assertLess(mx.abs(y - y_hat1).max(), 0.0003) | ||||
|         self.assertLess(mx.abs(y - y_hat2).max(), 0.02) | ||||
|  | ||||
|     def test_sin_pe(self): | ||||
|         m = nn.SinusoidalPositionalEncoding(16, min_freq=0.01) | ||||
| @@ -623,21 +439,6 @@ class TestNN(mlx_tests.MLXTestCase): | ||||
|             mx.abs(similarities[mx.arange(10), mx.arange(10)] - 1).max(), 1e-5 | ||||
|         ) | ||||
|  | ||||
|     def test_io(self): | ||||
|         def make_model(): | ||||
|             return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2)) | ||||
|  | ||||
|         m = make_model() | ||||
|         tdir = tempfile.TemporaryDirectory() | ||||
|         file = os.path.join(tdir.name, "model.npz") | ||||
|         m.save_weights(file) | ||||
|         m_load = make_model() | ||||
|         m_load.load_weights(file) | ||||
|         tdir.cleanup() | ||||
|  | ||||
|         eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters()) | ||||
|         self.assertTrue(all(tree_flatten(eq_tree))) | ||||
|  | ||||
|     def test_relu(self): | ||||
|         x = mx.array([1.0, -1.0, 0.0]) | ||||
|         y = nn.relu(x) | ||||
| @@ -787,24 +588,6 @@ class TestNN(mlx_tests.MLXTestCase): | ||||
|         y = alibi(x.astype(mx.float16)) | ||||
|         self.assertTrue(y.dtype, mx.float16) | ||||
|  | ||||
|     def test_hinge_loss(self): | ||||
|         inputs = mx.ones((2, 4)) | ||||
|         targets = mx.zeros((2, 4)) | ||||
|         loss = nn.losses.hinge_loss(inputs, targets, reduction="mean") | ||||
|         self.assertEqual(loss, 1.0) | ||||
|  | ||||
|     def test_huber_loss(self): | ||||
|         inputs = mx.ones((2, 4)) | ||||
|         targets = mx.zeros((2, 4)) | ||||
|         loss = nn.losses.huber_loss(inputs, targets, reduction="mean") | ||||
|         self.assertEqual(loss, 0.5) | ||||
|  | ||||
|     def test_log_cosh_loss(self): | ||||
|         inputs = mx.ones((2, 4)) | ||||
|         targets = mx.zeros((2, 4)) | ||||
|         loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean") | ||||
|         self.assertAlmostEqual(loss.item(), 0.433781, places=6) | ||||
|  | ||||
|     def test_dropout(self): | ||||
|         x = mx.ones((2, 4)) | ||||
|         y = nn.Dropout(0.5)(x) | ||||
|   | ||||