mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
default strict mode for module update and update_modules (#2239)
This commit is contained in:
@@ -219,6 +219,46 @@ class TestBase(mlx_tests.MLXTestCase):
|
||||
x = mx.zeros((3,))
|
||||
mx.grad(loss_fn)(model)
|
||||
|
||||
def test_update(self):
|
||||
m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
|
||||
|
||||
# Updating non-existent parameters
|
||||
with self.assertRaises(ValueError):
|
||||
updates = {"layers": [{"value": 0}]}
|
||||
m.update(updates)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
updates = {"layers": ["hello"]}
|
||||
m.update(updates)
|
||||
|
||||
# Wronge type
|
||||
with self.assertRaises(ValueError):
|
||||
updates = {"layers": [{"weight": "hi"}]}
|
||||
m.update(updates)
|
||||
|
||||
def test_update_modules(self):
|
||||
m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
|
||||
|
||||
# Updating non-existent modules should not be allowed by default
|
||||
with self.assertRaises(ValueError):
|
||||
m = m.update_modules({"values": [0, 1]})
|
||||
|
||||
# Update wrong types
|
||||
with self.assertRaises(ValueError):
|
||||
m = m.update_modules({"layers": [0, 1]})
|
||||
|
||||
class MyModule(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.test = mx.array(1.0)
|
||||
self.list = [mx.array(1.0), mx.array(2.0)]
|
||||
|
||||
m = MyModule()
|
||||
with self.assertRaises(ValueError):
|
||||
m = m.update_modules({"test": "hi"})
|
||||
with self.assertRaises(ValueError):
|
||||
m = m.update_modules({"list": ["hi"]})
|
||||
|
||||
|
||||
class TestLayers(mlx_tests.MLXTestCase):
|
||||
def test_identity(self):
|
||||
|
Reference in New Issue
Block a user