default strict mode for module update and update_modules (#2239)

This commit is contained in:
Awni Hannun
2025-06-05 15:27:02 -07:00
committed by GitHub
parent 52dc8c8cd5
commit c763fe1be0
2 changed files with 78 additions and 12 deletions

View File

@@ -219,6 +219,46 @@ class TestBase(mlx_tests.MLXTestCase):
x = mx.zeros((3,))
mx.grad(loss_fn)(model)
def test_update(self):
m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
# Updating non-existent parameters
with self.assertRaises(ValueError):
updates = {"layers": [{"value": 0}]}
m.update(updates)
with self.assertRaises(ValueError):
updates = {"layers": ["hello"]}
m.update(updates)
# Wronge type
with self.assertRaises(ValueError):
updates = {"layers": [{"weight": "hi"}]}
m.update(updates)
def test_update_modules(self):
m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
# Updating non-existent modules should not be allowed by default
with self.assertRaises(ValueError):
m = m.update_modules({"values": [0, 1]})
# Update wrong types
with self.assertRaises(ValueError):
m = m.update_modules({"layers": [0, 1]})
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.test = mx.array(1.0)
self.list = [mx.array(1.0), mx.array(2.0)]
m = MyModule()
with self.assertRaises(ValueError):
m = m.update_modules({"test": "hi"})
with self.assertRaises(ValueError):
m = m.update_modules({"list": ["hi"]})
class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self):