Fix optimizer reloading from checkpoint (#1329)

* fix optimizer reloading from checkpoint

* comment
Awni Hannun authored 2024-08-15 07:33:23 -07:00, committed by GitHub
parent d0630ffe8c
commit ae5b5cabfd
2 changed files with 58 additions and 12 deletions


@@ -48,8 +48,28 @@ class Optimizer:
         >>> optimizer.state.keys()
         dict_keys(['step', 'learning_rate', 'weight', 'bias'])
         """
-        self._state.update(tree_map(lambda x: {}, parameters))
-        tree_map(self.init_single, parameters, self._state)
+
+        # Initialize the optimizer state to match the parameter state
+        def update_state(params, state):
+            if isinstance(params, (list, tuple)):
+                state = list(state)
+                for i in range(len(state)):
+                    state[i] = update_state(params[i], state[i])
+                if len(state) != len(params):
+                    state.extend(tree_map(lambda x: {}, params[len(state) :]))
+                return type(params)(state)
+            elif isinstance(params, dict):
+                for k, v in params.items():
+                    if k not in state:
+                        state[k] = tree_map(lambda x: {}, v)
+                    else:
+                        state[k] = update_state(v, state[k])
+                return state
+            else:
+                return state
+
+        update_state(parameters, self._state)
+        tree_map(lambda p, s: s or self.init_single(p, s), parameters, self._state)
         self._initialized = True

     def init_single(self, parameter: mx.array, state: dict):
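
The key change in init: instead of unconditionally rebuilding the state (which clobbered anything loaded from a checkpoint), the new update_state helper walks the parameter tree and the existing state together, keeping entries the checkpoint provided and inserting empty dicts only for parameters it did not cover; the final tree_map then runs init_single only where a per-parameter state dict is still empty (the `s or self.init_single(p, s)` short-circuit). Here is the same merge logic run standalone on plain Python containers to show the effect — the parameter names and values are illustrative, not from the commit:

from mlx.utils import tree_map

def update_state(params, state):
    # Recursively align `state` with `params`: keep loaded entries and
    # pad anything the checkpoint is missing with empty dicts.
    if isinstance(params, (list, tuple)):
        state = list(state)
        for i in range(len(state)):
            state[i] = update_state(params[i], state[i])
        if len(state) != len(params):
            state.extend(tree_map(lambda x: {}, params[len(state) :]))
        return type(params)(state)
    elif isinstance(params, dict):
        for k, v in params.items():
            if k not in state:
                state[k] = tree_map(lambda x: {}, v)  # new parameter: empty state
            else:
                state[k] = update_state(v, state[k])  # shared key: recurse
        return state
    else:
        return state  # leaf: keep the loaded per-parameter state as-is

params = {"weight": 1.0, "bias": 2.0}  # current model parameters
loaded = {"weight": {"m": 0.5}}        # checkpoint written before "bias" existed
print(update_state(params, loaded))
# {'weight': {'m': 0.5}, 'bias': {}}  -> init_single later fills only 'bias'
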
@@ -104,7 +124,7 @@ class Optimizer:

     @state.setter
     def state(self, state: dict):
-        self._initialized = True
+        self._initialized = False
         self._state = state

     @property
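
And the companion fix: assigning to optimizer.state (as you do when restoring a checkpoint) previously marked the optimizer as initialized, so the loaded tree was trusted as complete; it now flags the optimizer as uninitialized, and the next init (explicit, or the lazy one on the first update) performs the merge above. A hedged sketch of the resume flow — the model, file name, and the flatten-to-safetensors save pattern are common MLX conventions assumed for the example, not part of this commit:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten, tree_unflatten

model = nn.Linear(4, 4)
optimizer = optim.Adam(learning_rate=1e-3)
optimizer.init(model.parameters())

# Save: flatten the nested state tree into a flat dict of arrays.
mx.save_safetensors("optimizer.safetensors", dict(tree_flatten(optimizer.state)))

# Resume (e.g. in a fresh process): rebuild the tree and assign it.
# After this commit the assignment marks the optimizer as uninitialized ...
optimizer = optim.Adam(learning_rate=1e-3)
optimizer.state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))

# ... so init() merges the loaded state into the current parameter tree
# instead of overwriting it, initializing only what the checkpoint lacked.
optimizer.init(model.parameters())
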