Support Chaining for some of functionalities of nn.Module (#885) (#897)

* add chaining support for some of the functionalities of "nn.Module"

* reformat

* change the return types

* remove return types

* add return type with forward referencing

* add tests for chaining

* add name to contributors

* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* update docstring

* update docstrings

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
AmirHossein_Razlighi
2024-03-28 06:28:29 +03:30
committed by GitHub
parent f30b659291
commit d611251502
3 changed files with 52 additions and 10 deletions

View File

@@ -162,6 +162,16 @@ class TestBase(mlx_tests.MLXTestCase):
m.state["hello"] = "world"
self.assertEqual(m.state["hello"], "world")
def test_chaining(self):
m = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1))
pre_freeze_num_params = len(m.parameters())
m.freeze().unfreeze()
self.assertEqual(len(m.parameters()), pre_freeze_num_params)
params_dict = m.parameters()
self.assertFalse(m.update(params_dict).eval()._training)
self.assertTrue(m.train()._training)
class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self):