mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Module checks the weight on load_weights
(#337)
* update module to check weights on load, also fix docs and reorganize tests * nits + rebase * a few more docs updates for Module * use manual module file * comment
This commit is contained in:
parent
0782a4573a
commit
dff4a3833f
33
docs/src/_templates/module-base-class.rst
Normal file
33
docs/src/_templates/module-base-class.rst
Normal file
@ -0,0 +1,33 @@
|
||||
{{ fullname | escape | underline}}
|
||||
|
||||
.. currentmodule:: {{ module }}
|
||||
|
||||
.. add toctree option to make autodoc generate the pages
|
||||
|
||||
.. autoclass:: {{ objname }}
|
||||
|
||||
{% block attributes %}
|
||||
{% if attributes %}
|
||||
.. rubric:: Attributes
|
||||
|
||||
.. autosummary::
|
||||
:toctree: .
|
||||
{% for item in attributes %}
|
||||
~{{ fullname }}.{{ item }}
|
||||
{%- endfor %}
|
||||
{% endif %}
|
||||
{% endblock %}
|
||||
|
||||
{% block methods %}
|
||||
{% if methods %}
|
||||
.. rubric:: Methods
|
||||
|
||||
.. autosummary::
|
||||
:toctree: .
|
||||
{% for item in methods %}
|
||||
{%- if item not in inherited_members and item != '__init__' %}
|
||||
~{{ fullname }}.{{ item }}
|
||||
{%- endif -%}
|
||||
{%- endfor %}
|
||||
{% endif %}
|
||||
{% endblock %}
|
@ -170,14 +170,13 @@ In detail:
|
||||
:meth:`mlx.core.value_and_grad`
|
||||
|
||||
.. autosummary::
|
||||
:recursive:
|
||||
:toctree: _autosummary
|
||||
|
||||
value_and_grad
|
||||
Module
|
||||
|
||||
.. toctree::
|
||||
|
||||
nn/module
|
||||
nn/layers
|
||||
nn/functions
|
||||
nn/losses
|
||||
|
36
docs/src/python/nn/module.rst
Normal file
36
docs/src/python/nn/module.rst
Normal file
@ -0,0 +1,36 @@
|
||||
Module
|
||||
======
|
||||
|
||||
.. currentmodule:: mlx.nn
|
||||
|
||||
.. autoclass:: Module
|
||||
|
||||
.. rubric:: Attributes
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
Module.training
|
||||
|
||||
.. rubric:: Methods
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
Module.apply
|
||||
Module.apply_to_modules
|
||||
Module.children
|
||||
Module.eval
|
||||
Module.filter_and_map
|
||||
Module.freeze
|
||||
Module.leaf_modules
|
||||
Module.load_weights
|
||||
Module.modules
|
||||
Module.named_modules
|
||||
Module.parameters
|
||||
Module.save_weights
|
||||
Module.train
|
||||
Module.trainable_parameters
|
||||
Module.unfreeze
|
||||
Module.update
|
||||
Module.update_modules
|
@ -1,7 +1,7 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import textwrap
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_flatten, tree_unflatten
|
||||
@ -61,6 +61,7 @@ class Module(dict):
|
||||
|
||||
@property
|
||||
def training(self):
|
||||
"""Boolean indicating if the model is in training mode."""
|
||||
return self._training
|
||||
|
||||
def _extra_repr(self):
|
||||
@ -87,15 +88,83 @@ class Module(dict):
|
||||
def __setattr__(self, key: str, val: Any):
|
||||
self[key] = val
|
||||
|
||||
def load_weights(self, file: str):
|
||||
def load_weights(
|
||||
self,
|
||||
file_or_weights: Union[str, List[Tuple[str, mx.array]]],
|
||||
strict: bool = True,
|
||||
):
|
||||
"""
|
||||
Load and update the model's weights from a `.npz` file.
|
||||
Update the model's weights from a ``.npz`` or a list.
|
||||
|
||||
Args:
|
||||
file_or_weights (str or list(tuple(str, mx.array))): The path to
|
||||
the weights ``.npz`` file or a list of pairs of parameter names
|
||||
and arrays.
|
||||
strict (bool, optional): If ``True`` then checks that the provided
|
||||
weights exactly match the parameters of the model. Otherwise,
|
||||
only the weights actually contained in the model are loaded and
|
||||
shapes are not checked. Default: ``True``.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
model = nn.Linear(10, 10)
|
||||
|
||||
# Load from file
|
||||
model.load_weights("weights.npz")
|
||||
|
||||
# Load from list
|
||||
weights = [
|
||||
("weight", mx.random.uniform(shape=(10, 10))),
|
||||
("bias", mx.zeros((10,))),
|
||||
]
|
||||
model.load_weights(weights)
|
||||
|
||||
# Missing weight
|
||||
weights = [
|
||||
("weight", mx.random.uniform(shape=(10, 10))),
|
||||
]
|
||||
|
||||
# Raises a ValueError exception
|
||||
model.load_weights(weights)
|
||||
|
||||
# Ok, only updates the weight but not the bias
|
||||
model.load_weights(weights, strict=False)
|
||||
"""
|
||||
self.update(tree_unflatten(list(mx.load(file).items())))
|
||||
weights = file_or_weights
|
||||
if isinstance(weights, str):
|
||||
weights = list(mx.load(weights).items())
|
||||
|
||||
if strict:
|
||||
new_weights = dict(weights)
|
||||
curr_weights = dict(tree_flatten(self.parameters()))
|
||||
if extras := (new_weights.keys() - curr_weights.keys()):
|
||||
extras = " ".join(extras)
|
||||
raise ValueError(f"Received parameters not in model: {extras}.")
|
||||
if missing := (curr_weights.keys() - new_weights.keys()):
|
||||
missing = " ".join(missing)
|
||||
raise ValueError(f"Missing parameters: {missing}.")
|
||||
for k, v in curr_weights.items():
|
||||
v_new = new_weights[k]
|
||||
if not isinstance(v_new, mx.array):
|
||||
raise ValueError(
|
||||
"Expected mx.array but received "
|
||||
f"{type(v_new)} for parameter {k}"
|
||||
)
|
||||
if v_new.shape != v.shape:
|
||||
raise ValueError(
|
||||
f"Expected shape {v.shape} but received "
|
||||
f" shape {v_new.shape} for parameter {k}"
|
||||
)
|
||||
|
||||
self.update(tree_unflatten(weights))
|
||||
|
||||
def save_weights(self, file: str):
|
||||
"""
|
||||
Save the model's weights to a `.npz` file.
|
||||
Save the model's weights to a ``.npz`` file.
|
||||
"""
|
||||
mx.savez(file, **dict(tree_flatten(self.parameters())))
|
||||
|
||||
@ -351,23 +420,26 @@ class Module(dict):
|
||||
"""Freeze the Module's parameters or some of them. Freezing a parameter means not
|
||||
computing gradients for it.
|
||||
|
||||
This function is idempotent ie freezing a frozen model is a noop.
|
||||
This function is idempotent i.e. freezing a frozen model is a no-op.
|
||||
|
||||
For instance to only train the attention parameters from a transformer:
|
||||
Example:
|
||||
For instance to only train the attention parameters from a Transformer:
|
||||
|
||||
model = ...
|
||||
.. code-block:: python
|
||||
|
||||
model = nn.Transformer()
|
||||
model.freeze()
|
||||
model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)
|
||||
|
||||
Args:
|
||||
recurse (bool, optional): If True then freeze the parameters of the
|
||||
submodules as well (default: True).
|
||||
submodules as well. Default: ``True``.
|
||||
keys (str or list[str], optional): If provided then only these
|
||||
parameters will be frozen otherwise all the parameters of a
|
||||
module. For instance freeze all biases by calling
|
||||
``module.freeze(keys="bias")``.
|
||||
strict (bool, optional): If set to True validate that the passed keys exist
|
||||
(default: False).
|
||||
strict (bool, optional): If set to ``True`` validate that the passed keys exist.
|
||||
Default: ``False``.
|
||||
"""
|
||||
|
||||
def _freeze_impl(_, m):
|
||||
@ -401,21 +473,25 @@ class Module(dict):
|
||||
This function is idempotent ie unfreezing a model that is not frozen is
|
||||
a noop.
|
||||
|
||||
For instance to only train the biases one can do:
|
||||
Example:
|
||||
|
||||
model = ...
|
||||
For instance to only train the biases of a Transformer one can do:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
model = nn.Transformer()
|
||||
model.freeze()
|
||||
model.unfreeze(keys="bias")
|
||||
|
||||
Args:
|
||||
recurse (bool, optional): If True then unfreeze the parameters of the
|
||||
submodules as well (default: True).
|
||||
submodules as well. Default: ``True``.
|
||||
keys (str or list[str], optional): If provided then only these
|
||||
parameters will be unfrozen otherwise all the parameters of a
|
||||
module. For instance unfreeze all biases by calling
|
||||
``module.unfreeze(keys="bias")``.
|
||||
strict (bool, optional): If set to True validate that the passed keys exist
|
||||
(default: False).
|
||||
strict (bool, optional): If set to ``True`` validate that the passed keys exist.
|
||||
Default: ``False``.
|
||||
"""
|
||||
|
||||
def _unfreeze_impl(_, m):
|
||||
@ -432,10 +508,25 @@ class Module(dict):
|
||||
_unfreeze_impl("", self)
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
"""Set the model in or out of training mode.
|
||||
|
||||
Training mode only applies to certain layers. For example
|
||||
:obj:`Dropout` applies a random mask in training mode, but is the
|
||||
identity in evaluation mode.
|
||||
|
||||
Args:
|
||||
mode (bool): Indicate if the model should be in training or
|
||||
evaluation mode. Default: ``True``.
|
||||
"""
|
||||
|
||||
def _set_train(_, m):
|
||||
m._training = mode
|
||||
|
||||
self.apply_to_modules(_set_train)
|
||||
|
||||
def eval(self):
|
||||
"""Set the model to evaluation mode.
|
||||
|
||||
See :func:`train`.
|
||||
"""
|
||||
self.train(False)
|
||||
|
279
python/tests/test_losses.py
Normal file
279
python/tests/test_losses.py
Normal file
@ -0,0 +1,279 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx_tests
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TestLosses(mlx_tests.MLXTestCase):
|
||||
def test_cross_entropy(self):
|
||||
logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
|
||||
targets = mx.array([0, 1])
|
||||
|
||||
# Test with reduction 'none'
|
||||
losses_none = nn.losses.cross_entropy(logits, targets, reduction="none")
|
||||
expected_none = mx.array([0.0, 0.0])
|
||||
self.assertTrue(mx.array_equal(losses_none, expected_none))
|
||||
|
||||
# Test with reduction 'mean'
|
||||
losses_mean = nn.losses.cross_entropy(logits, targets, reduction="mean")
|
||||
expected_mean = mx.mean(expected_none)
|
||||
self.assertEqual(losses_mean, expected_mean)
|
||||
|
||||
# Test with reduction 'sum'
|
||||
losses_sum = nn.losses.cross_entropy(logits, targets, reduction="sum")
|
||||
expected_sum = mx.sum(expected_none)
|
||||
self.assertEqual(losses_sum, expected_sum)
|
||||
|
||||
# Test cases with weights and no label smoothing
|
||||
logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
|
||||
targets = mx.array([0, 1])
|
||||
weights = mx.array([1.0, 2.0])
|
||||
|
||||
# Reduction 'none'
|
||||
losses_none = nn.losses.cross_entropy(
|
||||
logits,
|
||||
targets,
|
||||
weights=weights,
|
||||
reduction="none",
|
||||
)
|
||||
expected_none = mx.array([0.04858735, 0.0971747]) # Calculated losses
|
||||
self.assertTrue(
|
||||
np.allclose(losses_none, expected_none, atol=1e-5),
|
||||
"Test case failed for cross_entropy loss --reduction='none' --weights=[1.0, 2.0]",
|
||||
)
|
||||
|
||||
# Reduction 'mean'
|
||||
losses_mean = nn.losses.cross_entropy(
|
||||
logits,
|
||||
targets,
|
||||
weights=weights,
|
||||
reduction="mean",
|
||||
)
|
||||
expected_mean = mx.mean(expected_none)
|
||||
self.assertTrue(
|
||||
np.allclose(losses_mean, expected_mean, atol=1e-5),
|
||||
"Test case failed for cross_entropy loss --reduction='mean' --weights=[1.0, 2.0]",
|
||||
)
|
||||
|
||||
# Reduction 'sum'
|
||||
losses_sum = nn.losses.cross_entropy(
|
||||
logits,
|
||||
targets,
|
||||
weights=weights,
|
||||
reduction="sum",
|
||||
)
|
||||
expected_sum = mx.sum(expected_none)
|
||||
self.assertTrue(
|
||||
np.allclose(losses_sum, expected_sum, atol=1e-5),
|
||||
"Test case failed for cross_entropy loss --reduction='sum' --weights=[1.0, 2.0]",
|
||||
)
|
||||
|
||||
# Test case with equal weights and label smoothing > 0
|
||||
logits = mx.array(
|
||||
[[0, 0.2, 0.7, 0.1, 0], [0, 0.9, 0.2, 0.2, 1], [1, 0.2, 0.7, 0.9, 1]]
|
||||
)
|
||||
target = mx.array([2, 1, 0])
|
||||
|
||||
losses_none = nn.losses.cross_entropy(
|
||||
logits, target, label_smoothing=0.3, reduction="none"
|
||||
)
|
||||
expected_none = mx.array([1.29693, 1.38617, 1.48176])
|
||||
self.assertTrue(
|
||||
mx.allclose(expected_none, losses_none),
|
||||
"Test case failed for cross_entropy --label_smoothing=0.3 --reduction='none'",
|
||||
)
|
||||
|
||||
expected_mean = mx.mean(expected_none)
|
||||
losses_mean = nn.losses.cross_entropy(
|
||||
logits, target, label_smoothing=0.3, reduction="mean"
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.allclose(losses_mean, expected_mean),
|
||||
"Test case failed for cross_entropy --label_smoothing=0.3 --reduction='mean'",
|
||||
)
|
||||
|
||||
expected_sum = mx.sum(expected_none)
|
||||
losses_sum = nn.losses.cross_entropy(
|
||||
logits, target, label_smoothing=0.3, reduction="sum"
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.allclose(losses_sum, expected_sum),
|
||||
"Test case failed for cross_entropy --label_smoothing=0.3 --reduction='sum'",
|
||||
)
|
||||
|
||||
def test_l1_loss(self):
|
||||
predictions = mx.array([0.5, 0.2, 0.9, 0.0])
|
||||
targets = mx.array([0.5, 0.2, 0.9, 0.0])
|
||||
|
||||
# Expected result
|
||||
expected_none = mx.array([0, 0, 0, 0]).astype(mx.float32)
|
||||
expected_sum = mx.sum(expected_none)
|
||||
expected_mean = mx.mean(expected_none)
|
||||
|
||||
losses = nn.losses.l1_loss(predictions, targets, reduction="none")
|
||||
self.assertTrue(
|
||||
mx.array_equal(losses, expected_none),
|
||||
"Test failed for l1_loss --reduction='none'",
|
||||
)
|
||||
|
||||
losses = nn.losses.l1_loss(predictions, targets, reduction="sum")
|
||||
self.assertTrue(mx.array_equal(losses, expected_sum))
|
||||
|
||||
losses = nn.losses.l1_loss(predictions, targets, reduction="mean")
|
||||
self.assertTrue(mx.array_equal(losses, expected_mean))
|
||||
|
||||
def test_mse_loss(self):
|
||||
predictions = mx.array([0.5, 0.2, 0.9, 0.0])
|
||||
targets = mx.array([0.7, 0.1, 0.8, 0.2])
|
||||
|
||||
expected_none = mx.array([0.04, 0.01, 0.01, 0.04])
|
||||
expected_mean = mx.mean(expected_none)
|
||||
expected_sum = mx.sum(expected_none)
|
||||
|
||||
# Test with reduction 'none'
|
||||
losses_none = nn.losses.mse_loss(predictions, targets, reduction="none")
|
||||
self.assertTrue(
|
||||
np.allclose(losses_none, expected_none, 1e-5),
|
||||
"Test case failed for mse_loss --reduction='none'",
|
||||
)
|
||||
|
||||
# Test with reduction 'mean'
|
||||
losses_mean = nn.losses.mse_loss(predictions, targets, reduction="mean")
|
||||
self.assertEqual(
|
||||
losses_mean,
|
||||
expected_mean,
|
||||
"Test case failed for mse_loss --reduction='mean'",
|
||||
)
|
||||
|
||||
# Test with reduction 'sum'
|
||||
losses_sum = nn.losses.mse_loss(predictions, targets, reduction="sum")
|
||||
self.assertEqual(
|
||||
losses_sum, expected_sum, "Test case failed for mse_loss --reduction='sum'"
|
||||
)
|
||||
|
||||
def test_smooth_l1_loss(self):
|
||||
predictions = mx.array([1.5, 2.5, 0.5, 3.5])
|
||||
targets = mx.array([1.0, 2.0, 0.5, 2.5])
|
||||
beta = 1.0
|
||||
|
||||
# Expected results
|
||||
expected_none = mx.array([0.125, 0.125, 0.0, 0.5])
|
||||
expected_sum = mx.sum(expected_none)
|
||||
expected_mean = mx.mean(expected_none)
|
||||
|
||||
# Test with reduction 'none'
|
||||
loss_none = nn.losses.smooth_l1_loss(
|
||||
predictions, targets, beta, reduction="none"
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.array_equal(loss_none, expected_none),
|
||||
"Test case failed for smooth_l1_loss --reduction='none'",
|
||||
)
|
||||
|
||||
# Test with reduction 'sum'
|
||||
loss_sum = nn.losses.smooth_l1_loss(predictions, targets, beta, reduction="sum")
|
||||
self.assertEqual(
|
||||
loss_sum,
|
||||
expected_sum,
|
||||
"Test case failed for smooth_l1_loss --reduction='sum'",
|
||||
)
|
||||
|
||||
# Test with reduction 'mean'
|
||||
loss_mean = nn.losses.smooth_l1_loss(
|
||||
predictions, targets, beta, reduction="mean"
|
||||
)
|
||||
self.assertEqual(
|
||||
loss_mean,
|
||||
expected_mean,
|
||||
"Test case failed for smooth_l1_loss --reduction='mean'",
|
||||
)
|
||||
|
||||
def test_nll_loss(self):
|
||||
logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
|
||||
targets = mx.array([0, 1])
|
||||
|
||||
# Test with reduction 'none'
|
||||
losses_none = nn.losses.nll_loss(logits, targets, reduction="none")
|
||||
expected_none = mx.array([0.0, 0.0])
|
||||
self.assertTrue(mx.array_equal(losses_none, expected_none))
|
||||
|
||||
# Test with reduction 'mean'
|
||||
losses_mean = nn.losses.nll_loss(logits, targets, reduction="mean")
|
||||
expected_mean = mx.mean(expected_none)
|
||||
self.assertEqual(losses_mean, expected_mean)
|
||||
|
||||
# Test with reduction 'sum'
|
||||
losses_sum = nn.losses.nll_loss(logits, targets, reduction="sum")
|
||||
expected_sum = mx.sum(expected_none)
|
||||
self.assertEqual(losses_sum, expected_sum)
|
||||
|
||||
def test_kl_div_loss(self):
|
||||
p_logits = mx.log(mx.array([[0.5, 0.5], [0.8, 0.2]]))
|
||||
q_logits = mx.log(mx.array([[0.5, 0.5], [0.2, 0.8]]))
|
||||
|
||||
# Test with reduction 'none'
|
||||
losses_none = nn.losses.kl_div_loss(p_logits, q_logits, reduction="none")
|
||||
expected_none = mx.array([0.0, 0.831777])
|
||||
self.assertTrue(mx.allclose(losses_none, expected_none))
|
||||
|
||||
# Test with reduction 'mean'
|
||||
losses_mean = nn.losses.kl_div_loss(p_logits, q_logits, reduction="mean")
|
||||
expected_mean = mx.mean(expected_none)
|
||||
self.assertTrue(mx.allclose(losses_mean, expected_mean))
|
||||
|
||||
# Test with reduction 'sum'
|
||||
losses_sum = nn.losses.kl_div_loss(p_logits, q_logits, reduction="sum")
|
||||
expected_sum = mx.sum(expected_none)
|
||||
self.assertTrue(mx.allclose(losses_sum, expected_sum))
|
||||
|
||||
def test_triplet_loss(self):
|
||||
anchors = mx.array([[1, 2, 3], [1, 2, 3]])
|
||||
positives = mx.array([[4, 5, 6], [0, -1, 2]])
|
||||
negatives = mx.array([[7, 8, 9], [3, 2, 3]])
|
||||
|
||||
# Test with reduction 'none'
|
||||
losses_none = nn.losses.triplet_loss(
|
||||
anchors, positives, negatives, reduction="none"
|
||||
)
|
||||
expected_none = mx.array([0, 2.31662])
|
||||
self.assertTrue(mx.allclose(losses_none, expected_none))
|
||||
|
||||
# Test with reduction 'mean'
|
||||
losses_mean = nn.losses.triplet_loss(
|
||||
anchors, positives, negatives, reduction="mean"
|
||||
)
|
||||
expected_mean = mx.mean(expected_none)
|
||||
self.assertTrue(mx.allclose(losses_mean, expected_mean))
|
||||
|
||||
# Test with reduction 'sum'
|
||||
losses_sum = nn.losses.triplet_loss(
|
||||
anchors, positives, negatives, reduction="sum"
|
||||
)
|
||||
expected_sum = mx.sum(expected_none)
|
||||
self.assertTrue(mx.allclose(losses_sum, expected_sum))
|
||||
|
||||
def test_hinge_loss(self):
|
||||
inputs = mx.ones((2, 4))
|
||||
targets = mx.zeros((2, 4))
|
||||
loss = nn.losses.hinge_loss(inputs, targets, reduction="mean")
|
||||
self.assertEqual(loss, 1.0)
|
||||
|
||||
def test_huber_loss(self):
|
||||
inputs = mx.ones((2, 4))
|
||||
targets = mx.zeros((2, 4))
|
||||
loss = nn.losses.huber_loss(inputs, targets, reduction="mean")
|
||||
self.assertEqual(loss, 0.5)
|
||||
|
||||
def test_log_cosh_loss(self):
|
||||
inputs = mx.ones((2, 4))
|
||||
targets = mx.zeros((2, 4))
|
||||
loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
|
||||
self.assertAlmostEqual(loss.item(), 0.433781, places=6)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -11,7 +11,112 @@ import numpy as np
|
||||
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||
|
||||
|
||||
class TestNN(mlx_tests.MLXTestCase):
|
||||
class TestBase(mlx_tests.MLXTestCase):
|
||||
def test_module_utilities(self):
|
||||
m = nn.Sequential(
|
||||
nn.Sequential(nn.Linear(2, 10), nn.relu),
|
||||
nn.Sequential(nn.Linear(10, 10), nn.ReLU()),
|
||||
nn.Linear(10, 1),
|
||||
mx.sigmoid,
|
||||
)
|
||||
|
||||
children = m.children()
|
||||
self.assertTrue(isinstance(children, dict))
|
||||
self.assertEqual(len(children), 1)
|
||||
self.assertTrue(isinstance(children["layers"], list))
|
||||
self.assertEqual(len(children["layers"]), 4)
|
||||
self.assertEqual(children["layers"][3], {})
|
||||
flat_children = tree_flatten(children, is_leaf=nn.Module.is_module)
|
||||
self.assertEqual(len(flat_children), 3)
|
||||
|
||||
leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module)
|
||||
self.assertEqual(len(leaves), 4)
|
||||
self.assertEqual(leaves[0][0], "layers.0.layers.0")
|
||||
self.assertEqual(leaves[1][0], "layers.1.layers.0")
|
||||
self.assertEqual(leaves[2][0], "layers.1.layers.1")
|
||||
self.assertEqual(leaves[3][0], "layers.2")
|
||||
self.assertTrue(leaves[0][1] is m.layers[0].layers[0])
|
||||
self.assertTrue(leaves[1][1] is m.layers[1].layers[0])
|
||||
self.assertTrue(leaves[2][1] is m.layers[1].layers[1])
|
||||
self.assertTrue(leaves[3][1] is m.layers[2])
|
||||
|
||||
m.eval()
|
||||
|
||||
def assert_not_training(k, m):
|
||||
self.assertFalse(m.training)
|
||||
|
||||
m.apply_to_modules(assert_not_training)
|
||||
|
||||
m.train()
|
||||
|
||||
def assert_training(k, m):
|
||||
self.assertTrue(m.training)
|
||||
|
||||
m.apply_to_modules(assert_training)
|
||||
|
||||
def test_io(self):
|
||||
def make_model():
|
||||
return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
|
||||
|
||||
m = make_model()
|
||||
tdir = tempfile.TemporaryDirectory()
|
||||
file = os.path.join(tdir.name, "model.npz")
|
||||
m.save_weights(file)
|
||||
m_load = make_model()
|
||||
m_load.load_weights(file)
|
||||
tdir.cleanup()
|
||||
|
||||
eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
|
||||
self.assertTrue(all(tree_flatten(eq_tree)))
|
||||
|
||||
def test_load_from_weights(self):
|
||||
m = nn.Linear(2, 2)
|
||||
|
||||
# Too few weights
|
||||
weights = [("weight", mx.ones((2, 2)))]
|
||||
with self.assertRaises(ValueError):
|
||||
m.load_weights(weights)
|
||||
|
||||
m.load_weights(weights, strict=False)
|
||||
self.assertTrue(mx.array_equal(m.weight, weights[0][1]))
|
||||
|
||||
# Wrong name
|
||||
with self.assertRaises(ValueError):
|
||||
m.load_weights([("weihgt", mx.ones((2, 2)))])
|
||||
|
||||
# Ok
|
||||
m.load_weights([("weihgt", mx.ones((2, 2)))], strict=False)
|
||||
|
||||
# Too many weights
|
||||
with self.assertRaises(ValueError):
|
||||
m.load_weights(
|
||||
[
|
||||
("weight", mx.ones((2, 2))),
|
||||
("bias", mx.ones((2,))),
|
||||
("bias2", mx.ones((2,))),
|
||||
]
|
||||
)
|
||||
|
||||
# Wrong shape
|
||||
with self.assertRaises(ValueError):
|
||||
m.load_weights(
|
||||
[
|
||||
("weight", mx.ones((2, 2))),
|
||||
("bias", mx.ones((2, 1))),
|
||||
]
|
||||
)
|
||||
|
||||
# Wrong type
|
||||
with self.assertRaises(ValueError):
|
||||
m.load_weights(
|
||||
[
|
||||
("weight", mx.ones((2, 2))),
|
||||
("bias", 3),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class TestLayers(mlx_tests.MLXTestCase):
|
||||
def test_identity(self):
|
||||
inputs = mx.zeros((10, 4))
|
||||
layer = nn.Identity()
|
||||
@ -31,272 +136,6 @@ class TestNN(mlx_tests.MLXTestCase):
|
||||
outputs = layer(inputs1, inputs2)
|
||||
self.assertEqual(tuple(outputs.shape), (10, 6))
|
||||
|
||||
def test_cross_entropy(self):
|
||||
logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
|
||||
targets = mx.array([0, 1])
|
||||
|
||||
# Test with reduction 'none'
|
||||
losses_none = nn.losses.cross_entropy(logits, targets, reduction="none")
|
||||
expected_none = mx.array([0.0, 0.0])
|
||||
self.assertTrue(mx.array_equal(losses_none, expected_none))
|
||||
|
||||
# Test with reduction 'mean'
|
||||
losses_mean = nn.losses.cross_entropy(logits, targets, reduction="mean")
|
||||
expected_mean = mx.mean(expected_none)
|
||||
self.assertEqual(losses_mean, expected_mean)
|
||||
|
||||
# Test with reduction 'sum'
|
||||
losses_sum = nn.losses.cross_entropy(logits, targets, reduction="sum")
|
||||
expected_sum = mx.sum(expected_none)
|
||||
self.assertEqual(losses_sum, expected_sum)
|
||||
|
||||
# Test cases with weights and no label smoothing
|
||||
logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
|
||||
targets = mx.array([0, 1])
|
||||
weights = mx.array([1.0, 2.0])
|
||||
|
||||
# Reduction 'none'
|
||||
losses_none = nn.losses.cross_entropy(
|
||||
logits,
|
||||
targets,
|
||||
weights=weights,
|
||||
reduction="none",
|
||||
)
|
||||
expected_none = mx.array([0.04858735, 0.0971747]) # Calculated losses
|
||||
self.assertTrue(
|
||||
np.allclose(losses_none, expected_none, atol=1e-5),
|
||||
"Test case failed for cross_entropy loss --reduction='none' --weights=[1.0, 2.0]",
|
||||
)
|
||||
|
||||
# Reduction 'mean'
|
||||
losses_mean = nn.losses.cross_entropy(
|
||||
logits,
|
||||
targets,
|
||||
weights=weights,
|
||||
reduction="mean",
|
||||
)
|
||||
expected_mean = mx.mean(expected_none)
|
||||
self.assertTrue(
|
||||
np.allclose(losses_mean, expected_mean, atol=1e-5),
|
||||
"Test case failed for cross_entropy loss --reduction='mean' --weights=[1.0, 2.0]",
|
||||
)
|
||||
|
||||
# Reduction 'sum'
|
||||
losses_sum = nn.losses.cross_entropy(
|
||||
logits,
|
||||
targets,
|
||||
weights=weights,
|
||||
reduction="sum",
|
||||
)
|
||||
expected_sum = mx.sum(expected_none)
|
||||
self.assertTrue(
|
||||
np.allclose(losses_sum, expected_sum, atol=1e-5),
|
||||
"Test case failed for cross_entropy loss --reduction='sum' --weights=[1.0, 2.0]",
|
||||
)
|
||||
|
||||
# Test case with equal weights and label smoothing > 0
|
||||
logits = mx.array(
|
||||
[[0, 0.2, 0.7, 0.1, 0], [0, 0.9, 0.2, 0.2, 1], [1, 0.2, 0.7, 0.9, 1]]
|
||||
)
|
||||
target = mx.array([2, 1, 0])
|
||||
|
||||
losses_none = nn.losses.cross_entropy(
|
||||
logits, target, label_smoothing=0.3, reduction="none"
|
||||
)
|
||||
expected_none = mx.array([1.29693, 1.38617, 1.48176])
|
||||
self.assertTrue(
|
||||
mx.allclose(expected_none, losses_none),
|
||||
"Test case failed for cross_entropy --label_smoothing=0.3 --reduction='none'",
|
||||
)
|
||||
|
||||
expected_mean = mx.mean(expected_none)
|
||||
losses_mean = nn.losses.cross_entropy(
|
||||
logits, target, label_smoothing=0.3, reduction="mean"
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.allclose(losses_mean, expected_mean),
|
||||
"Test case failed for cross_entropy --label_smoothing=0.3 --reduction='mean'",
|
||||
)
|
||||
|
||||
expected_sum = mx.sum(expected_none)
|
||||
losses_sum = nn.losses.cross_entropy(
|
||||
logits, target, label_smoothing=0.3, reduction="sum"
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.allclose(losses_sum, expected_sum),
|
||||
"Test case failed for cross_entropy --label_smoothing=0.3 --reduction='sum'",
|
||||
)
|
||||
|
||||
def test_l1_loss(self):
|
||||
predictions = mx.array([0.5, 0.2, 0.9, 0.0])
|
||||
targets = mx.array([0.5, 0.2, 0.9, 0.0])
|
||||
|
||||
# Expected result
|
||||
expected_none = mx.array([0, 0, 0, 0]).astype(mx.float32)
|
||||
expected_sum = mx.sum(expected_none)
|
||||
expected_mean = mx.mean(expected_none)
|
||||
|
||||
losses = nn.losses.l1_loss(predictions, targets, reduction="none")
|
||||
self.assertTrue(
|
||||
mx.array_equal(losses, expected_none),
|
||||
"Test failed for l1_loss --reduction='none'",
|
||||
)
|
||||
|
||||
losses = nn.losses.l1_loss(predictions, targets, reduction="sum")
|
||||
self.assertTrue(mx.array_equal(losses, expected_sum))
|
||||
|
||||
losses = nn.losses.l1_loss(predictions, targets, reduction="mean")
|
||||
self.assertTrue(mx.array_equal(losses, expected_mean))
|
||||
|
||||
def test_mse_loss(self):
|
||||
predictions = mx.array([0.5, 0.2, 0.9, 0.0])
|
||||
targets = mx.array([0.7, 0.1, 0.8, 0.2])
|
||||
|
||||
expected_none = mx.array([0.04, 0.01, 0.01, 0.04])
|
||||
expected_mean = mx.mean(expected_none)
|
||||
expected_sum = mx.sum(expected_none)
|
||||
|
||||
# Test with reduction 'none'
|
||||
losses_none = nn.losses.mse_loss(predictions, targets, reduction="none")
|
||||
self.assertTrue(
|
||||
np.allclose(losses_none, expected_none, 1e-5),
|
||||
"Test case failed for mse_loss --reduction='none'",
|
||||
)
|
||||
|
||||
# Test with reduction 'mean'
|
||||
losses_mean = nn.losses.mse_loss(predictions, targets, reduction="mean")
|
||||
self.assertEqual(
|
||||
losses_mean,
|
||||
expected_mean,
|
||||
"Test case failed for mse_loss --reduction='mean'",
|
||||
)
|
||||
|
||||
# Test with reduction 'sum'
|
||||
losses_sum = nn.losses.mse_loss(predictions, targets, reduction="sum")
|
||||
self.assertEqual(
|
||||
losses_sum, expected_sum, "Test case failed for mse_loss --reduction='sum'"
|
||||
)
|
||||
|
||||
def test_smooth_l1_loss(self):
|
||||
predictions = mx.array([1.5, 2.5, 0.5, 3.5])
|
||||
targets = mx.array([1.0, 2.0, 0.5, 2.5])
|
||||
beta = 1.0
|
||||
|
||||
# Expected results
|
||||
expected_none = mx.array([0.125, 0.125, 0.0, 0.5])
|
||||
expected_sum = mx.sum(expected_none)
|
||||
expected_mean = mx.mean(expected_none)
|
||||
|
||||
# Test with reduction 'none'
|
||||
loss_none = nn.losses.smooth_l1_loss(
|
||||
predictions, targets, beta, reduction="none"
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.array_equal(loss_none, expected_none),
|
||||
"Test case failed for smooth_l1_loss --reduction='none'",
|
||||
)
|
||||
|
||||
# Test with reduction 'sum'
|
||||
loss_sum = nn.losses.smooth_l1_loss(predictions, targets, beta, reduction="sum")
|
||||
self.assertEqual(
|
||||
loss_sum,
|
||||
expected_sum,
|
||||
"Test case failed for smooth_l1_loss --reduction='sum'",
|
||||
)
|
||||
|
||||
# Test with reduction 'mean'
|
||||
loss_mean = nn.losses.smooth_l1_loss(
|
||||
predictions, targets, beta, reduction="mean"
|
||||
)
|
||||
self.assertEqual(
|
||||
loss_mean,
|
||||
expected_mean,
|
||||
"Test case failed for smooth_l1_loss --reduction='mean'",
|
||||
)
|
||||
|
||||
def test_nll_loss(self):
|
||||
logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
|
||||
targets = mx.array([0, 1])
|
||||
|
||||
# Test with reduction 'none'
|
||||
losses_none = nn.losses.nll_loss(logits, targets, reduction="none")
|
||||
expected_none = mx.array([0.0, 0.0])
|
||||
self.assertTrue(mx.array_equal(losses_none, expected_none))
|
||||
|
||||
# Test with reduction 'mean'
|
||||
losses_mean = nn.losses.nll_loss(logits, targets, reduction="mean")
|
||||
expected_mean = mx.mean(expected_none)
|
||||
self.assertEqual(losses_mean, expected_mean)
|
||||
|
||||
# Test with reduction 'sum'
|
||||
losses_sum = nn.losses.nll_loss(logits, targets, reduction="sum")
|
||||
expected_sum = mx.sum(expected_none)
|
||||
self.assertEqual(losses_sum, expected_sum)
|
||||
|
||||
def test_kl_div_loss(self):
|
||||
p_logits = mx.log(mx.array([[0.5, 0.5], [0.8, 0.2]]))
|
||||
q_logits = mx.log(mx.array([[0.5, 0.5], [0.2, 0.8]]))
|
||||
|
||||
# Test with reduction 'none'
|
||||
losses_none = nn.losses.kl_div_loss(p_logits, q_logits, reduction="none")
|
||||
expected_none = mx.array([0.0, 0.831777])
|
||||
self.assertTrue(mx.allclose(losses_none, expected_none))
|
||||
|
||||
# Test with reduction 'mean'
|
||||
losses_mean = nn.losses.kl_div_loss(p_logits, q_logits, reduction="mean")
|
||||
expected_mean = mx.mean(expected_none)
|
||||
self.assertTrue(mx.allclose(losses_mean, expected_mean))
|
||||
|
||||
# Test with reduction 'sum'
|
||||
losses_sum = nn.losses.kl_div_loss(p_logits, q_logits, reduction="sum")
|
||||
expected_sum = mx.sum(expected_none)
|
||||
self.assertTrue(mx.allclose(losses_sum, expected_sum))
|
||||
|
||||
def test_triplet_loss(self):
|
||||
anchors = mx.array([[1, 2, 3], [1, 2, 3]])
|
||||
positives = mx.array([[4, 5, 6], [0, -1, 2]])
|
||||
negatives = mx.array([[7, 8, 9], [3, 2, 3]])
|
||||
|
||||
# Test with reduction 'none'
|
||||
losses_none = nn.losses.triplet_loss(
|
||||
anchors, positives, negatives, reduction="none"
|
||||
)
|
||||
expected_none = mx.array([0, 2.31662])
|
||||
self.assertTrue(mx.allclose(losses_none, expected_none))
|
||||
|
||||
# Test with reduction 'mean'
|
||||
losses_mean = nn.losses.triplet_loss(
|
||||
anchors, positives, negatives, reduction="mean"
|
||||
)
|
||||
expected_mean = mx.mean(expected_none)
|
||||
self.assertTrue(mx.allclose(losses_mean, expected_mean))
|
||||
|
||||
# Test with reduction 'sum'
|
||||
losses_sum = nn.losses.triplet_loss(
|
||||
anchors, positives, negatives, reduction="sum"
|
||||
)
|
||||
expected_sum = mx.sum(expected_none)
|
||||
self.assertTrue(mx.allclose(losses_sum, expected_sum))
|
||||
|
||||
def test_gelu(self):
|
||||
inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]
|
||||
|
||||
# From: jax.nn.gelu(np.array(inputs), approximate=False)
|
||||
expected = np.array(
|
||||
[1.0093501, -0.16925684, 0.22918941, 0.60498625, 0.49459383]
|
||||
)
|
||||
|
||||
out = nn.GELU()(mx.array(inputs))
|
||||
self.assertTrue(np.allclose(out, expected))
|
||||
|
||||
# Crudely check the approximations
|
||||
x = mx.arange(-6.0, 6.0, 12 / 100)
|
||||
y = nn.gelu(x)
|
||||
y_hat1 = nn.gelu_approx(x)
|
||||
y_hat2 = nn.gelu_fast_approx(x)
|
||||
self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
|
||||
self.assertLess(mx.abs(y - y_hat2).max(), 0.02)
|
||||
|
||||
def test_group_norm(self):
|
||||
x = mx.arange(100, dtype=mx.float32)
|
||||
x = x.reshape(1, 10, 10, 1)
|
||||
@ -570,47 +409,24 @@ class TestNN(mlx_tests.MLXTestCase):
|
||||
y2 = m(x)
|
||||
self.assertTrue(mx.array_equal(y, y2))
|
||||
|
||||
def test_module_utilities(self):
|
||||
m = nn.Sequential(
|
||||
nn.Sequential(nn.Linear(2, 10), nn.relu),
|
||||
nn.Sequential(nn.Linear(10, 10), nn.ReLU()),
|
||||
nn.Linear(10, 1),
|
||||
mx.sigmoid,
|
||||
def test_gelu(self):
|
||||
inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]
|
||||
|
||||
# From: jax.nn.gelu(np.array(inputs), approximate=False)
|
||||
expected = np.array(
|
||||
[1.0093501, -0.16925684, 0.22918941, 0.60498625, 0.49459383]
|
||||
)
|
||||
|
||||
children = m.children()
|
||||
self.assertTrue(isinstance(children, dict))
|
||||
self.assertEqual(len(children), 1)
|
||||
self.assertTrue(isinstance(children["layers"], list))
|
||||
self.assertEqual(len(children["layers"]), 4)
|
||||
self.assertEqual(children["layers"][3], {})
|
||||
flat_children = tree_flatten(children, is_leaf=nn.Module.is_module)
|
||||
self.assertEqual(len(flat_children), 3)
|
||||
out = nn.GELU()(mx.array(inputs))
|
||||
self.assertTrue(np.allclose(out, expected))
|
||||
|
||||
leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module)
|
||||
self.assertEqual(len(leaves), 4)
|
||||
self.assertEqual(leaves[0][0], "layers.0.layers.0")
|
||||
self.assertEqual(leaves[1][0], "layers.1.layers.0")
|
||||
self.assertEqual(leaves[2][0], "layers.1.layers.1")
|
||||
self.assertEqual(leaves[3][0], "layers.2")
|
||||
self.assertTrue(leaves[0][1] is m.layers[0].layers[0])
|
||||
self.assertTrue(leaves[1][1] is m.layers[1].layers[0])
|
||||
self.assertTrue(leaves[2][1] is m.layers[1].layers[1])
|
||||
self.assertTrue(leaves[3][1] is m.layers[2])
|
||||
|
||||
m.eval()
|
||||
|
||||
def assert_not_training(k, m):
|
||||
self.assertFalse(m.training)
|
||||
|
||||
m.apply_to_modules(assert_not_training)
|
||||
|
||||
m.train()
|
||||
|
||||
def assert_training(k, m):
|
||||
self.assertTrue(m.training)
|
||||
|
||||
m.apply_to_modules(assert_training)
|
||||
# Crudely check the approximations
|
||||
x = mx.arange(-6.0, 6.0, 12 / 100)
|
||||
y = nn.gelu(x)
|
||||
y_hat1 = nn.gelu_approx(x)
|
||||
y_hat2 = nn.gelu_fast_approx(x)
|
||||
self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
|
||||
self.assertLess(mx.abs(y - y_hat2).max(), 0.02)
|
||||
|
||||
def test_sin_pe(self):
|
||||
m = nn.SinusoidalPositionalEncoding(16, min_freq=0.01)
|
||||
@ -623,21 +439,6 @@ class TestNN(mlx_tests.MLXTestCase):
|
||||
mx.abs(similarities[mx.arange(10), mx.arange(10)] - 1).max(), 1e-5
|
||||
)
|
||||
|
||||
def test_io(self):
|
||||
def make_model():
|
||||
return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
|
||||
|
||||
m = make_model()
|
||||
tdir = tempfile.TemporaryDirectory()
|
||||
file = os.path.join(tdir.name, "model.npz")
|
||||
m.save_weights(file)
|
||||
m_load = make_model()
|
||||
m_load.load_weights(file)
|
||||
tdir.cleanup()
|
||||
|
||||
eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
|
||||
self.assertTrue(all(tree_flatten(eq_tree)))
|
||||
|
||||
def test_relu(self):
|
||||
x = mx.array([1.0, -1.0, 0.0])
|
||||
y = nn.relu(x)
|
||||
@ -787,24 +588,6 @@ class TestNN(mlx_tests.MLXTestCase):
|
||||
y = alibi(x.astype(mx.float16))
|
||||
self.assertTrue(y.dtype, mx.float16)
|
||||
|
||||
def test_hinge_loss(self):
|
||||
inputs = mx.ones((2, 4))
|
||||
targets = mx.zeros((2, 4))
|
||||
loss = nn.losses.hinge_loss(inputs, targets, reduction="mean")
|
||||
self.assertEqual(loss, 1.0)
|
||||
|
||||
def test_huber_loss(self):
|
||||
inputs = mx.ones((2, 4))
|
||||
targets = mx.zeros((2, 4))
|
||||
loss = nn.losses.huber_loss(inputs, targets, reduction="mean")
|
||||
self.assertEqual(loss, 0.5)
|
||||
|
||||
def test_log_cosh_loss(self):
|
||||
inputs = mx.ones((2, 4))
|
||||
targets = mx.zeros((2, 4))
|
||||
loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
|
||||
self.assertAlmostEqual(loss.item(), 0.433781, places=6)
|
||||
|
||||
def test_dropout(self):
|
||||
x = mx.ones((2, 4))
|
||||
y = nn.Dropout(0.5)(x)
|
||||
|
Loading…
Reference in New Issue
Block a user