diff --git a/docs/src/_templates/module-base-class.rst b/docs/src/_templates/module-base-class.rst new file mode 100644 index 000000000..08371fb01 --- /dev/null +++ b/docs/src/_templates/module-base-class.rst @@ -0,0 +1,33 @@ +{{ fullname | escape | underline}} + +.. currentmodule:: {{ module }} + +.. add toctree option to make autodoc generate the pages + +.. autoclass:: {{ objname }} + + {% block attributes %} + {% if attributes %} + .. rubric:: Attributes + + .. autosummary:: + :toctree: . + {% for item in attributes %} + ~{{ fullname }}.{{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block methods %} + {% if methods %} + .. rubric:: Methods + + .. autosummary:: + :toctree: . + {% for item in methods %} + {%- if item not in inherited_members and item != '__init__' %} + ~{{ fullname }}.{{ item }} + {%- endif -%} + {%- endfor %} + {% endif %} + {% endblock %} diff --git a/docs/src/python/nn.rst b/docs/src/python/nn.rst index 4c9868171..496c27823 100644 --- a/docs/src/python/nn.rst +++ b/docs/src/python/nn.rst @@ -170,14 +170,13 @@ In detail: :meth:`mlx.core.value_and_grad` .. autosummary:: - :recursive: :toctree: _autosummary value_and_grad - Module .. toctree:: + nn/module nn/layers nn/functions nn/losses diff --git a/docs/src/python/nn/module.rst b/docs/src/python/nn/module.rst new file mode 100644 index 000000000..042a88028 --- /dev/null +++ b/docs/src/python/nn/module.rst @@ -0,0 +1,36 @@ +Module +====== + +.. currentmodule:: mlx.nn + +.. autoclass:: Module + + .. rubric:: Attributes + + .. autosummary:: + :toctree: _autosummary + + Module.training + + .. rubric:: Methods + + .. autosummary:: + :toctree: _autosummary + + Module.apply + Module.apply_to_modules + Module.children + Module.eval + Module.filter_and_map + Module.freeze + Module.leaf_modules + Module.load_weights + Module.modules + Module.named_modules + Module.parameters + Module.save_weights + Module.train + Module.trainable_parameters + Module.unfreeze + Module.update + Module.update_modules diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index dcf079457..646f5f2dc 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -1,7 +1,7 @@ # Copyright © 2023 Apple Inc. import textwrap -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, List, Optional, Tuple, Union import mlx.core as mx from mlx.utils import tree_flatten, tree_unflatten @@ -61,6 +61,7 @@ class Module(dict): @property def training(self): + """Boolean indicating if the model is in training mode.""" return self._training def _extra_repr(self): @@ -87,15 +88,83 @@ class Module(dict): def __setattr__(self, key: str, val: Any): self[key] = val - def load_weights(self, file: str): + def load_weights( + self, + file_or_weights: Union[str, List[Tuple[str, mx.array]]], + strict: bool = True, + ): """ - Load and update the model's weights from a `.npz` file. + Update the model's weights from a ``.npz`` or a list. + + Args: + file_or_weights (str or list(tuple(str, mx.array))): The path to + the weights ``.npz`` file or a list of pairs of parameter names + and arrays. + strict (bool, optional): If ``True`` then checks that the provided + weights exactly match the parameters of the model. Otherwise, + only the weights actually contained in the model are loaded and + shapes are not checked. Default: ``True``. + + Example: + + .. code-block:: python + + import mlx.core as mx + import mlx.nn as nn + model = nn.Linear(10, 10) + + # Load from file + model.load_weights("weights.npz") + + # Load from list + weights = [ + ("weight", mx.random.uniform(shape=(10, 10))), + ("bias", mx.zeros((10,))), + ] + model.load_weights(weights) + + # Missing weight + weights = [ + ("weight", mx.random.uniform(shape=(10, 10))), + ] + + # Raises a ValueError exception + model.load_weights(weights) + + # Ok, only updates the weight but not the bias + model.load_weights(weights, strict=False) """ - self.update(tree_unflatten(list(mx.load(file).items()))) + weights = file_or_weights + if isinstance(weights, str): + weights = list(mx.load(weights).items()) + + if strict: + new_weights = dict(weights) + curr_weights = dict(tree_flatten(self.parameters())) + if extras := (new_weights.keys() - curr_weights.keys()): + extras = " ".join(extras) + raise ValueError(f"Received parameters not in model: {extras}.") + if missing := (curr_weights.keys() - new_weights.keys()): + missing = " ".join(missing) + raise ValueError(f"Missing parameters: {missing}.") + for k, v in curr_weights.items(): + v_new = new_weights[k] + if not isinstance(v_new, mx.array): + raise ValueError( + "Expected mx.array but received " + f"{type(v_new)} for parameter {k}" + ) + if v_new.shape != v.shape: + raise ValueError( + f"Expected shape {v.shape} but received " + f" shape {v_new.shape} for parameter {k}" + ) + + self.update(tree_unflatten(weights)) def save_weights(self, file: str): """ - Save the model's weights to a `.npz` file. + Save the model's weights to a ``.npz`` file. """ mx.savez(file, **dict(tree_flatten(self.parameters()))) @@ -351,23 +420,26 @@ class Module(dict): """Freeze the Module's parameters or some of them. Freezing a parameter means not computing gradients for it. - This function is idempotent ie freezing a frozen model is a noop. + This function is idempotent i.e. freezing a frozen model is a no-op. - For instance to only train the attention parameters from a transformer: + Example: + For instance to only train the attention parameters from a Transformer: - model = ... - model.freeze() - model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None) + .. code-block:: python + + model = nn.Transformer() + model.freeze() + model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None) Args: recurse (bool, optional): If True then freeze the parameters of the - submodules as well (default: True). + submodules as well. Default: ``True``. keys (str or list[str], optional): If provided then only these parameters will be frozen otherwise all the parameters of a module. For instance freeze all biases by calling ``module.freeze(keys="bias")``. - strict (bool, optional): If set to True validate that the passed keys exist - (default: False). + strict (bool, optional): If set to ``True`` validate that the passed keys exist. + Default: ``False``. """ def _freeze_impl(_, m): @@ -401,21 +473,25 @@ class Module(dict): This function is idempotent ie unfreezing a model that is not frozen is a noop. - For instance to only train the biases one can do: + Example: - model = ... - model.freeze() - model.unfreeze(keys="bias") + For instance to only train the biases of a Transformer one can do: + + .. code-block:: python + + model = nn.Transformer() + model.freeze() + model.unfreeze(keys="bias") Args: recurse (bool, optional): If True then unfreeze the parameters of the - submodules as well (default: True). + submodules as well. Default: ``True``. keys (str or list[str], optional): If provided then only these parameters will be unfrozen otherwise all the parameters of a module. For instance unfreeze all biases by calling ``module.unfreeze(keys="bias")``. - strict (bool, optional): If set to True validate that the passed keys exist - (default: False). + strict (bool, optional): If set to ``True`` validate that the passed keys exist. + Default: ``False``. """ def _unfreeze_impl(_, m): @@ -432,10 +508,25 @@ class Module(dict): _unfreeze_impl("", self) def train(self, mode: bool = True): + """Set the model in or out of training mode. + + Training mode only applies to certain layers. For example + :obj:`Dropout` applies a random mask in training mode, but is the + identity in evaluation mode. + + Args: + mode (bool): Indicate if the model should be in training or + evaluation mode. Default: ``True``. + """ + def _set_train(_, m): m._training = mode self.apply_to_modules(_set_train) def eval(self): + """Set the model to evaluation mode. + + See :func:`train`. + """ self.train(False) diff --git a/python/tests/test_losses.py b/python/tests/test_losses.py new file mode 100644 index 000000000..706e336c2 --- /dev/null +++ b/python/tests/test_losses.py @@ -0,0 +1,279 @@ +# Copyright © 2023 Apple Inc. + +import unittest + +import mlx.core as mx +import mlx.nn as nn +import mlx_tests +import numpy as np + + +class TestLosses(mlx_tests.MLXTestCase): + def test_cross_entropy(self): + logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]]) + targets = mx.array([0, 1]) + + # Test with reduction 'none' + losses_none = nn.losses.cross_entropy(logits, targets, reduction="none") + expected_none = mx.array([0.0, 0.0]) + self.assertTrue(mx.array_equal(losses_none, expected_none)) + + # Test with reduction 'mean' + losses_mean = nn.losses.cross_entropy(logits, targets, reduction="mean") + expected_mean = mx.mean(expected_none) + self.assertEqual(losses_mean, expected_mean) + + # Test with reduction 'sum' + losses_sum = nn.losses.cross_entropy(logits, targets, reduction="sum") + expected_sum = mx.sum(expected_none) + self.assertEqual(losses_sum, expected_sum) + + # Test cases with weights and no label smoothing + logits = mx.array([[2.0, -1.0], [-1.0, 2.0]]) + targets = mx.array([0, 1]) + weights = mx.array([1.0, 2.0]) + + # Reduction 'none' + losses_none = nn.losses.cross_entropy( + logits, + targets, + weights=weights, + reduction="none", + ) + expected_none = mx.array([0.04858735, 0.0971747]) # Calculated losses + self.assertTrue( + np.allclose(losses_none, expected_none, atol=1e-5), + "Test case failed for cross_entropy loss --reduction='none' --weights=[1.0, 2.0]", + ) + + # Reduction 'mean' + losses_mean = nn.losses.cross_entropy( + logits, + targets, + weights=weights, + reduction="mean", + ) + expected_mean = mx.mean(expected_none) + self.assertTrue( + np.allclose(losses_mean, expected_mean, atol=1e-5), + "Test case failed for cross_entropy loss --reduction='mean' --weights=[1.0, 2.0]", + ) + + # Reduction 'sum' + losses_sum = nn.losses.cross_entropy( + logits, + targets, + weights=weights, + reduction="sum", + ) + expected_sum = mx.sum(expected_none) + self.assertTrue( + np.allclose(losses_sum, expected_sum, atol=1e-5), + "Test case failed for cross_entropy loss --reduction='sum' --weights=[1.0, 2.0]", + ) + + # Test case with equal weights and label smoothing > 0 + logits = mx.array( + [[0, 0.2, 0.7, 0.1, 0], [0, 0.9, 0.2, 0.2, 1], [1, 0.2, 0.7, 0.9, 1]] + ) + target = mx.array([2, 1, 0]) + + losses_none = nn.losses.cross_entropy( + logits, target, label_smoothing=0.3, reduction="none" + ) + expected_none = mx.array([1.29693, 1.38617, 1.48176]) + self.assertTrue( + mx.allclose(expected_none, losses_none), + "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='none'", + ) + + expected_mean = mx.mean(expected_none) + losses_mean = nn.losses.cross_entropy( + logits, target, label_smoothing=0.3, reduction="mean" + ) + self.assertTrue( + mx.allclose(losses_mean, expected_mean), + "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='mean'", + ) + + expected_sum = mx.sum(expected_none) + losses_sum = nn.losses.cross_entropy( + logits, target, label_smoothing=0.3, reduction="sum" + ) + self.assertTrue( + mx.allclose(losses_sum, expected_sum), + "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='sum'", + ) + + def test_l1_loss(self): + predictions = mx.array([0.5, 0.2, 0.9, 0.0]) + targets = mx.array([0.5, 0.2, 0.9, 0.0]) + + # Expected result + expected_none = mx.array([0, 0, 0, 0]).astype(mx.float32) + expected_sum = mx.sum(expected_none) + expected_mean = mx.mean(expected_none) + + losses = nn.losses.l1_loss(predictions, targets, reduction="none") + self.assertTrue( + mx.array_equal(losses, expected_none), + "Test failed for l1_loss --reduction='none'", + ) + + losses = nn.losses.l1_loss(predictions, targets, reduction="sum") + self.assertTrue(mx.array_equal(losses, expected_sum)) + + losses = nn.losses.l1_loss(predictions, targets, reduction="mean") + self.assertTrue(mx.array_equal(losses, expected_mean)) + + def test_mse_loss(self): + predictions = mx.array([0.5, 0.2, 0.9, 0.0]) + targets = mx.array([0.7, 0.1, 0.8, 0.2]) + + expected_none = mx.array([0.04, 0.01, 0.01, 0.04]) + expected_mean = mx.mean(expected_none) + expected_sum = mx.sum(expected_none) + + # Test with reduction 'none' + losses_none = nn.losses.mse_loss(predictions, targets, reduction="none") + self.assertTrue( + np.allclose(losses_none, expected_none, 1e-5), + "Test case failed for mse_loss --reduction='none'", + ) + + # Test with reduction 'mean' + losses_mean = nn.losses.mse_loss(predictions, targets, reduction="mean") + self.assertEqual( + losses_mean, + expected_mean, + "Test case failed for mse_loss --reduction='mean'", + ) + + # Test with reduction 'sum' + losses_sum = nn.losses.mse_loss(predictions, targets, reduction="sum") + self.assertEqual( + losses_sum, expected_sum, "Test case failed for mse_loss --reduction='sum'" + ) + + def test_smooth_l1_loss(self): + predictions = mx.array([1.5, 2.5, 0.5, 3.5]) + targets = mx.array([1.0, 2.0, 0.5, 2.5]) + beta = 1.0 + + # Expected results + expected_none = mx.array([0.125, 0.125, 0.0, 0.5]) + expected_sum = mx.sum(expected_none) + expected_mean = mx.mean(expected_none) + + # Test with reduction 'none' + loss_none = nn.losses.smooth_l1_loss( + predictions, targets, beta, reduction="none" + ) + self.assertTrue( + mx.array_equal(loss_none, expected_none), + "Test case failed for smooth_l1_loss --reduction='none'", + ) + + # Test with reduction 'sum' + loss_sum = nn.losses.smooth_l1_loss(predictions, targets, beta, reduction="sum") + self.assertEqual( + loss_sum, + expected_sum, + "Test case failed for smooth_l1_loss --reduction='sum'", + ) + + # Test with reduction 'mean' + loss_mean = nn.losses.smooth_l1_loss( + predictions, targets, beta, reduction="mean" + ) + self.assertEqual( + loss_mean, + expected_mean, + "Test case failed for smooth_l1_loss --reduction='mean'", + ) + + def test_nll_loss(self): + logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]]) + targets = mx.array([0, 1]) + + # Test with reduction 'none' + losses_none = nn.losses.nll_loss(logits, targets, reduction="none") + expected_none = mx.array([0.0, 0.0]) + self.assertTrue(mx.array_equal(losses_none, expected_none)) + + # Test with reduction 'mean' + losses_mean = nn.losses.nll_loss(logits, targets, reduction="mean") + expected_mean = mx.mean(expected_none) + self.assertEqual(losses_mean, expected_mean) + + # Test with reduction 'sum' + losses_sum = nn.losses.nll_loss(logits, targets, reduction="sum") + expected_sum = mx.sum(expected_none) + self.assertEqual(losses_sum, expected_sum) + + def test_kl_div_loss(self): + p_logits = mx.log(mx.array([[0.5, 0.5], [0.8, 0.2]])) + q_logits = mx.log(mx.array([[0.5, 0.5], [0.2, 0.8]])) + + # Test with reduction 'none' + losses_none = nn.losses.kl_div_loss(p_logits, q_logits, reduction="none") + expected_none = mx.array([0.0, 0.831777]) + self.assertTrue(mx.allclose(losses_none, expected_none)) + + # Test with reduction 'mean' + losses_mean = nn.losses.kl_div_loss(p_logits, q_logits, reduction="mean") + expected_mean = mx.mean(expected_none) + self.assertTrue(mx.allclose(losses_mean, expected_mean)) + + # Test with reduction 'sum' + losses_sum = nn.losses.kl_div_loss(p_logits, q_logits, reduction="sum") + expected_sum = mx.sum(expected_none) + self.assertTrue(mx.allclose(losses_sum, expected_sum)) + + def test_triplet_loss(self): + anchors = mx.array([[1, 2, 3], [1, 2, 3]]) + positives = mx.array([[4, 5, 6], [0, -1, 2]]) + negatives = mx.array([[7, 8, 9], [3, 2, 3]]) + + # Test with reduction 'none' + losses_none = nn.losses.triplet_loss( + anchors, positives, negatives, reduction="none" + ) + expected_none = mx.array([0, 2.31662]) + self.assertTrue(mx.allclose(losses_none, expected_none)) + + # Test with reduction 'mean' + losses_mean = nn.losses.triplet_loss( + anchors, positives, negatives, reduction="mean" + ) + expected_mean = mx.mean(expected_none) + self.assertTrue(mx.allclose(losses_mean, expected_mean)) + + # Test with reduction 'sum' + losses_sum = nn.losses.triplet_loss( + anchors, positives, negatives, reduction="sum" + ) + expected_sum = mx.sum(expected_none) + self.assertTrue(mx.allclose(losses_sum, expected_sum)) + + def test_hinge_loss(self): + inputs = mx.ones((2, 4)) + targets = mx.zeros((2, 4)) + loss = nn.losses.hinge_loss(inputs, targets, reduction="mean") + self.assertEqual(loss, 1.0) + + def test_huber_loss(self): + inputs = mx.ones((2, 4)) + targets = mx.zeros((2, 4)) + loss = nn.losses.huber_loss(inputs, targets, reduction="mean") + self.assertEqual(loss, 0.5) + + def test_log_cosh_loss(self): + inputs = mx.ones((2, 4)) + targets = mx.zeros((2, 4)) + loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean") + self.assertAlmostEqual(loss.item(), 0.433781, places=6) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index a6ce87f34..8529f33a6 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -11,7 +11,112 @@ import numpy as np from mlx.utils import tree_flatten, tree_map, tree_unflatten -class TestNN(mlx_tests.MLXTestCase): +class TestBase(mlx_tests.MLXTestCase): + def test_module_utilities(self): + m = nn.Sequential( + nn.Sequential(nn.Linear(2, 10), nn.relu), + nn.Sequential(nn.Linear(10, 10), nn.ReLU()), + nn.Linear(10, 1), + mx.sigmoid, + ) + + children = m.children() + self.assertTrue(isinstance(children, dict)) + self.assertEqual(len(children), 1) + self.assertTrue(isinstance(children["layers"], list)) + self.assertEqual(len(children["layers"]), 4) + self.assertEqual(children["layers"][3], {}) + flat_children = tree_flatten(children, is_leaf=nn.Module.is_module) + self.assertEqual(len(flat_children), 3) + + leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module) + self.assertEqual(len(leaves), 4) + self.assertEqual(leaves[0][0], "layers.0.layers.0") + self.assertEqual(leaves[1][0], "layers.1.layers.0") + self.assertEqual(leaves[2][0], "layers.1.layers.1") + self.assertEqual(leaves[3][0], "layers.2") + self.assertTrue(leaves[0][1] is m.layers[0].layers[0]) + self.assertTrue(leaves[1][1] is m.layers[1].layers[0]) + self.assertTrue(leaves[2][1] is m.layers[1].layers[1]) + self.assertTrue(leaves[3][1] is m.layers[2]) + + m.eval() + + def assert_not_training(k, m): + self.assertFalse(m.training) + + m.apply_to_modules(assert_not_training) + + m.train() + + def assert_training(k, m): + self.assertTrue(m.training) + + m.apply_to_modules(assert_training) + + def test_io(self): + def make_model(): + return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2)) + + m = make_model() + tdir = tempfile.TemporaryDirectory() + file = os.path.join(tdir.name, "model.npz") + m.save_weights(file) + m_load = make_model() + m_load.load_weights(file) + tdir.cleanup() + + eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters()) + self.assertTrue(all(tree_flatten(eq_tree))) + + def test_load_from_weights(self): + m = nn.Linear(2, 2) + + # Too few weights + weights = [("weight", mx.ones((2, 2)))] + with self.assertRaises(ValueError): + m.load_weights(weights) + + m.load_weights(weights, strict=False) + self.assertTrue(mx.array_equal(m.weight, weights[0][1])) + + # Wrong name + with self.assertRaises(ValueError): + m.load_weights([("weihgt", mx.ones((2, 2)))]) + + # Ok + m.load_weights([("weihgt", mx.ones((2, 2)))], strict=False) + + # Too many weights + with self.assertRaises(ValueError): + m.load_weights( + [ + ("weight", mx.ones((2, 2))), + ("bias", mx.ones((2,))), + ("bias2", mx.ones((2,))), + ] + ) + + # Wrong shape + with self.assertRaises(ValueError): + m.load_weights( + [ + ("weight", mx.ones((2, 2))), + ("bias", mx.ones((2, 1))), + ] + ) + + # Wrong type + with self.assertRaises(ValueError): + m.load_weights( + [ + ("weight", mx.ones((2, 2))), + ("bias", 3), + ] + ) + + +class TestLayers(mlx_tests.MLXTestCase): def test_identity(self): inputs = mx.zeros((10, 4)) layer = nn.Identity() @@ -31,272 +136,6 @@ class TestNN(mlx_tests.MLXTestCase): outputs = layer(inputs1, inputs2) self.assertEqual(tuple(outputs.shape), (10, 6)) - def test_cross_entropy(self): - logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]]) - targets = mx.array([0, 1]) - - # Test with reduction 'none' - losses_none = nn.losses.cross_entropy(logits, targets, reduction="none") - expected_none = mx.array([0.0, 0.0]) - self.assertTrue(mx.array_equal(losses_none, expected_none)) - - # Test with reduction 'mean' - losses_mean = nn.losses.cross_entropy(logits, targets, reduction="mean") - expected_mean = mx.mean(expected_none) - self.assertEqual(losses_mean, expected_mean) - - # Test with reduction 'sum' - losses_sum = nn.losses.cross_entropy(logits, targets, reduction="sum") - expected_sum = mx.sum(expected_none) - self.assertEqual(losses_sum, expected_sum) - - # Test cases with weights and no label smoothing - logits = mx.array([[2.0, -1.0], [-1.0, 2.0]]) - targets = mx.array([0, 1]) - weights = mx.array([1.0, 2.0]) - - # Reduction 'none' - losses_none = nn.losses.cross_entropy( - logits, - targets, - weights=weights, - reduction="none", - ) - expected_none = mx.array([0.04858735, 0.0971747]) # Calculated losses - self.assertTrue( - np.allclose(losses_none, expected_none, atol=1e-5), - "Test case failed for cross_entropy loss --reduction='none' --weights=[1.0, 2.0]", - ) - - # Reduction 'mean' - losses_mean = nn.losses.cross_entropy( - logits, - targets, - weights=weights, - reduction="mean", - ) - expected_mean = mx.mean(expected_none) - self.assertTrue( - np.allclose(losses_mean, expected_mean, atol=1e-5), - "Test case failed for cross_entropy loss --reduction='mean' --weights=[1.0, 2.0]", - ) - - # Reduction 'sum' - losses_sum = nn.losses.cross_entropy( - logits, - targets, - weights=weights, - reduction="sum", - ) - expected_sum = mx.sum(expected_none) - self.assertTrue( - np.allclose(losses_sum, expected_sum, atol=1e-5), - "Test case failed for cross_entropy loss --reduction='sum' --weights=[1.0, 2.0]", - ) - - # Test case with equal weights and label smoothing > 0 - logits = mx.array( - [[0, 0.2, 0.7, 0.1, 0], [0, 0.9, 0.2, 0.2, 1], [1, 0.2, 0.7, 0.9, 1]] - ) - target = mx.array([2, 1, 0]) - - losses_none = nn.losses.cross_entropy( - logits, target, label_smoothing=0.3, reduction="none" - ) - expected_none = mx.array([1.29693, 1.38617, 1.48176]) - self.assertTrue( - mx.allclose(expected_none, losses_none), - "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='none'", - ) - - expected_mean = mx.mean(expected_none) - losses_mean = nn.losses.cross_entropy( - logits, target, label_smoothing=0.3, reduction="mean" - ) - self.assertTrue( - mx.allclose(losses_mean, expected_mean), - "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='mean'", - ) - - expected_sum = mx.sum(expected_none) - losses_sum = nn.losses.cross_entropy( - logits, target, label_smoothing=0.3, reduction="sum" - ) - self.assertTrue( - mx.allclose(losses_sum, expected_sum), - "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='sum'", - ) - - def test_l1_loss(self): - predictions = mx.array([0.5, 0.2, 0.9, 0.0]) - targets = mx.array([0.5, 0.2, 0.9, 0.0]) - - # Expected result - expected_none = mx.array([0, 0, 0, 0]).astype(mx.float32) - expected_sum = mx.sum(expected_none) - expected_mean = mx.mean(expected_none) - - losses = nn.losses.l1_loss(predictions, targets, reduction="none") - self.assertTrue( - mx.array_equal(losses, expected_none), - "Test failed for l1_loss --reduction='none'", - ) - - losses = nn.losses.l1_loss(predictions, targets, reduction="sum") - self.assertTrue(mx.array_equal(losses, expected_sum)) - - losses = nn.losses.l1_loss(predictions, targets, reduction="mean") - self.assertTrue(mx.array_equal(losses, expected_mean)) - - def test_mse_loss(self): - predictions = mx.array([0.5, 0.2, 0.9, 0.0]) - targets = mx.array([0.7, 0.1, 0.8, 0.2]) - - expected_none = mx.array([0.04, 0.01, 0.01, 0.04]) - expected_mean = mx.mean(expected_none) - expected_sum = mx.sum(expected_none) - - # Test with reduction 'none' - losses_none = nn.losses.mse_loss(predictions, targets, reduction="none") - self.assertTrue( - np.allclose(losses_none, expected_none, 1e-5), - "Test case failed for mse_loss --reduction='none'", - ) - - # Test with reduction 'mean' - losses_mean = nn.losses.mse_loss(predictions, targets, reduction="mean") - self.assertEqual( - losses_mean, - expected_mean, - "Test case failed for mse_loss --reduction='mean'", - ) - - # Test with reduction 'sum' - losses_sum = nn.losses.mse_loss(predictions, targets, reduction="sum") - self.assertEqual( - losses_sum, expected_sum, "Test case failed for mse_loss --reduction='sum'" - ) - - def test_smooth_l1_loss(self): - predictions = mx.array([1.5, 2.5, 0.5, 3.5]) - targets = mx.array([1.0, 2.0, 0.5, 2.5]) - beta = 1.0 - - # Expected results - expected_none = mx.array([0.125, 0.125, 0.0, 0.5]) - expected_sum = mx.sum(expected_none) - expected_mean = mx.mean(expected_none) - - # Test with reduction 'none' - loss_none = nn.losses.smooth_l1_loss( - predictions, targets, beta, reduction="none" - ) - self.assertTrue( - mx.array_equal(loss_none, expected_none), - "Test case failed for smooth_l1_loss --reduction='none'", - ) - - # Test with reduction 'sum' - loss_sum = nn.losses.smooth_l1_loss(predictions, targets, beta, reduction="sum") - self.assertEqual( - loss_sum, - expected_sum, - "Test case failed for smooth_l1_loss --reduction='sum'", - ) - - # Test with reduction 'mean' - loss_mean = nn.losses.smooth_l1_loss( - predictions, targets, beta, reduction="mean" - ) - self.assertEqual( - loss_mean, - expected_mean, - "Test case failed for smooth_l1_loss --reduction='mean'", - ) - - def test_nll_loss(self): - logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]]) - targets = mx.array([0, 1]) - - # Test with reduction 'none' - losses_none = nn.losses.nll_loss(logits, targets, reduction="none") - expected_none = mx.array([0.0, 0.0]) - self.assertTrue(mx.array_equal(losses_none, expected_none)) - - # Test with reduction 'mean' - losses_mean = nn.losses.nll_loss(logits, targets, reduction="mean") - expected_mean = mx.mean(expected_none) - self.assertEqual(losses_mean, expected_mean) - - # Test with reduction 'sum' - losses_sum = nn.losses.nll_loss(logits, targets, reduction="sum") - expected_sum = mx.sum(expected_none) - self.assertEqual(losses_sum, expected_sum) - - def test_kl_div_loss(self): - p_logits = mx.log(mx.array([[0.5, 0.5], [0.8, 0.2]])) - q_logits = mx.log(mx.array([[0.5, 0.5], [0.2, 0.8]])) - - # Test with reduction 'none' - losses_none = nn.losses.kl_div_loss(p_logits, q_logits, reduction="none") - expected_none = mx.array([0.0, 0.831777]) - self.assertTrue(mx.allclose(losses_none, expected_none)) - - # Test with reduction 'mean' - losses_mean = nn.losses.kl_div_loss(p_logits, q_logits, reduction="mean") - expected_mean = mx.mean(expected_none) - self.assertTrue(mx.allclose(losses_mean, expected_mean)) - - # Test with reduction 'sum' - losses_sum = nn.losses.kl_div_loss(p_logits, q_logits, reduction="sum") - expected_sum = mx.sum(expected_none) - self.assertTrue(mx.allclose(losses_sum, expected_sum)) - - def test_triplet_loss(self): - anchors = mx.array([[1, 2, 3], [1, 2, 3]]) - positives = mx.array([[4, 5, 6], [0, -1, 2]]) - negatives = mx.array([[7, 8, 9], [3, 2, 3]]) - - # Test with reduction 'none' - losses_none = nn.losses.triplet_loss( - anchors, positives, negatives, reduction="none" - ) - expected_none = mx.array([0, 2.31662]) - self.assertTrue(mx.allclose(losses_none, expected_none)) - - # Test with reduction 'mean' - losses_mean = nn.losses.triplet_loss( - anchors, positives, negatives, reduction="mean" - ) - expected_mean = mx.mean(expected_none) - self.assertTrue(mx.allclose(losses_mean, expected_mean)) - - # Test with reduction 'sum' - losses_sum = nn.losses.triplet_loss( - anchors, positives, negatives, reduction="sum" - ) - expected_sum = mx.sum(expected_none) - self.assertTrue(mx.allclose(losses_sum, expected_sum)) - - def test_gelu(self): - inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414] - - # From: jax.nn.gelu(np.array(inputs), approximate=False) - expected = np.array( - [1.0093501, -0.16925684, 0.22918941, 0.60498625, 0.49459383] - ) - - out = nn.GELU()(mx.array(inputs)) - self.assertTrue(np.allclose(out, expected)) - - # Crudely check the approximations - x = mx.arange(-6.0, 6.0, 12 / 100) - y = nn.gelu(x) - y_hat1 = nn.gelu_approx(x) - y_hat2 = nn.gelu_fast_approx(x) - self.assertLess(mx.abs(y - y_hat1).max(), 0.0003) - self.assertLess(mx.abs(y - y_hat2).max(), 0.02) - def test_group_norm(self): x = mx.arange(100, dtype=mx.float32) x = x.reshape(1, 10, 10, 1) @@ -570,47 +409,24 @@ class TestNN(mlx_tests.MLXTestCase): y2 = m(x) self.assertTrue(mx.array_equal(y, y2)) - def test_module_utilities(self): - m = nn.Sequential( - nn.Sequential(nn.Linear(2, 10), nn.relu), - nn.Sequential(nn.Linear(10, 10), nn.ReLU()), - nn.Linear(10, 1), - mx.sigmoid, + def test_gelu(self): + inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414] + + # From: jax.nn.gelu(np.array(inputs), approximate=False) + expected = np.array( + [1.0093501, -0.16925684, 0.22918941, 0.60498625, 0.49459383] ) - children = m.children() - self.assertTrue(isinstance(children, dict)) - self.assertEqual(len(children), 1) - self.assertTrue(isinstance(children["layers"], list)) - self.assertEqual(len(children["layers"]), 4) - self.assertEqual(children["layers"][3], {}) - flat_children = tree_flatten(children, is_leaf=nn.Module.is_module) - self.assertEqual(len(flat_children), 3) + out = nn.GELU()(mx.array(inputs)) + self.assertTrue(np.allclose(out, expected)) - leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module) - self.assertEqual(len(leaves), 4) - self.assertEqual(leaves[0][0], "layers.0.layers.0") - self.assertEqual(leaves[1][0], "layers.1.layers.0") - self.assertEqual(leaves[2][0], "layers.1.layers.1") - self.assertEqual(leaves[3][0], "layers.2") - self.assertTrue(leaves[0][1] is m.layers[0].layers[0]) - self.assertTrue(leaves[1][1] is m.layers[1].layers[0]) - self.assertTrue(leaves[2][1] is m.layers[1].layers[1]) - self.assertTrue(leaves[3][1] is m.layers[2]) - - m.eval() - - def assert_not_training(k, m): - self.assertFalse(m.training) - - m.apply_to_modules(assert_not_training) - - m.train() - - def assert_training(k, m): - self.assertTrue(m.training) - - m.apply_to_modules(assert_training) + # Crudely check the approximations + x = mx.arange(-6.0, 6.0, 12 / 100) + y = nn.gelu(x) + y_hat1 = nn.gelu_approx(x) + y_hat2 = nn.gelu_fast_approx(x) + self.assertLess(mx.abs(y - y_hat1).max(), 0.0003) + self.assertLess(mx.abs(y - y_hat2).max(), 0.02) def test_sin_pe(self): m = nn.SinusoidalPositionalEncoding(16, min_freq=0.01) @@ -623,21 +439,6 @@ class TestNN(mlx_tests.MLXTestCase): mx.abs(similarities[mx.arange(10), mx.arange(10)] - 1).max(), 1e-5 ) - def test_io(self): - def make_model(): - return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2)) - - m = make_model() - tdir = tempfile.TemporaryDirectory() - file = os.path.join(tdir.name, "model.npz") - m.save_weights(file) - m_load = make_model() - m_load.load_weights(file) - tdir.cleanup() - - eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters()) - self.assertTrue(all(tree_flatten(eq_tree))) - def test_relu(self): x = mx.array([1.0, -1.0, 0.0]) y = nn.relu(x) @@ -787,24 +588,6 @@ class TestNN(mlx_tests.MLXTestCase): y = alibi(x.astype(mx.float16)) self.assertTrue(y.dtype, mx.float16) - def test_hinge_loss(self): - inputs = mx.ones((2, 4)) - targets = mx.zeros((2, 4)) - loss = nn.losses.hinge_loss(inputs, targets, reduction="mean") - self.assertEqual(loss, 1.0) - - def test_huber_loss(self): - inputs = mx.ones((2, 4)) - targets = mx.zeros((2, 4)) - loss = nn.losses.huber_loss(inputs, targets, reduction="mean") - self.assertEqual(loss, 0.5) - - def test_log_cosh_loss(self): - inputs = mx.ones((2, 4)) - targets = mx.zeros((2, 4)) - loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean") - self.assertAlmostEqual(loss.item(), 0.433781, places=6) - def test_dropout(self): x = mx.ones((2, 4)) y = nn.Dropout(0.5)(x)