mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
* add chaining support for some of the functionalities of "nn.Module" * reformat * change the return types * remove return types * add return type with forward referencing * add tests for chaining * add name to contributors * Update python/mlx/nn/layers/base.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/mlx/nn/layers/base.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * update docstring * update docstrings --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:

committed by
GitHub

parent
f30b659291
commit
d611251502
@@ -162,6 +162,16 @@ class TestBase(mlx_tests.MLXTestCase):
|
||||
m.state["hello"] = "world"
|
||||
self.assertEqual(m.state["hello"], "world")
|
||||
|
||||
def test_chaining(self):
|
||||
m = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1))
|
||||
pre_freeze_num_params = len(m.parameters())
|
||||
m.freeze().unfreeze()
|
||||
self.assertEqual(len(m.parameters()), pre_freeze_num_params)
|
||||
params_dict = m.parameters()
|
||||
|
||||
self.assertFalse(m.update(params_dict).eval()._training)
|
||||
self.assertTrue(m.train()._training)
|
||||
|
||||
|
||||
class TestLayers(mlx_tests.MLXTestCase):
|
||||
def test_identity(self):
|
||||
|
Reference in New Issue
Block a user