mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
loading empty list is ok when strict = false (#1834)
This commit is contained in:
parent
fe5987b81d
commit
ca305afdbe
@ -192,7 +192,8 @@ class Module(dict):
|
|||||||
f"shape {v_new.shape} for parameter {k}"
|
f"shape {v_new.shape} for parameter {k}"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.update(tree_unflatten(weights))
|
if len(weights) != 0:
|
||||||
|
self.update(tree_unflatten(weights))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def save_weights(self, file: str):
|
def save_weights(self, file: str):
|
||||||
|
@ -167,6 +167,9 @@ class TestBase(mlx_tests.MLXTestCase):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Empty weights is ok if strict is false
|
||||||
|
m.load_weights([], strict=False)
|
||||||
|
|
||||||
def test_module_state(self):
|
def test_module_state(self):
|
||||||
m = nn.Linear(10, 1)
|
m = nn.Linear(10, 1)
|
||||||
m.state["hello"] = "world"
|
m.state["hello"] = "world"
|
||||||
|
Loading…
Reference in New Issue
Block a user