mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
Fix optimizer reloading from checkpoint (#1329)
* fix optimizer reloading from checkpoint * comment
This commit is contained in:
@@ -48,8 +48,28 @@ class Optimizer:
|
||||
>>> optimizer.state.keys()
|
||||
dict_keys(['step', 'learning_rate', 'weight', 'bias'])
|
||||
"""
|
||||
self._state.update(tree_map(lambda x: {}, parameters))
|
||||
tree_map(self.init_single, parameters, self._state)
|
||||
|
||||
# Iniatilize the optimizer state to match the parameter state
|
||||
def update_state(params, state):
|
||||
if isinstance(params, (list, tuple)):
|
||||
state = list(state)
|
||||
for i in range(len(state)):
|
||||
state[i] = update_state(params[i], state[i])
|
||||
if len(state) != len(params):
|
||||
state.extend(tree_map(lambda x: {}, params[len(state) :]))
|
||||
return type(params)(state)
|
||||
elif isinstance(params, dict):
|
||||
for k, v in params.items():
|
||||
if k not in state:
|
||||
state[k] = tree_map(lambda x: {}, v)
|
||||
else:
|
||||
state[k] = update_state(v, state[k])
|
||||
return state
|
||||
else:
|
||||
return state
|
||||
|
||||
update_state(parameters, self._state)
|
||||
tree_map(lambda p, s: s or self.init_single(p, s), parameters, self._state)
|
||||
self._initialized = True
|
||||
|
||||
def init_single(self, parameter: mx.array, state: dict):
|
||||
@@ -104,7 +124,7 @@ class Optimizer:
|
||||
|
||||
@state.setter
|
||||
def state(self, state: dict):
|
||||
self._initialized = True
|
||||
self._initialized = False
|
||||
self._state = state
|
||||
|
||||
@property
|
||||
|
Reference in New Issue
Block a user